Support Round op in the CUDA EP (#2601)

* Support Round op for the CUDA EP

* Update version

* Fix build

* Fix opset version

* Update

* PR comments

* Fix build

* Nit
This commit is contained in:
Hariharan Seshadri 2019-12-19 11:36:50 -08:00 committed by GitHub
parent 715e365723
commit b3d0b114fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 44 additions and 11 deletions

View file

@ -119,6 +119,24 @@ __device__ __inline__ double _Erf(double a) { return erf(a); }
// half specialization of _Erf: widens the input to float, evaluates erff on
// it, and converts the result back to half.
template <>
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }
// Element-wise rounding helper. The primary template is declared only; each
// supported element type provides its own explicit specialization.
// NOTE(review): CUDA documents rintf/rint as rounding halfway cases to the
// nearest even integer — confirm this matches the ONNX Round contract.
template <typename T>
__device__ __inline__ T _Round(T a);
// float: forwards to the single-precision rintf.
template <>
__device__ __inline__ float _Round(float a) { return rintf(a); }
// double: forwards to the double-precision rint.
template <>
__device__ __inline__ double _Round(double a) { return rint(a); }
// half specialization of _Round.
// hrint is used only when __CUDA_ARCH__ is defined and is at least 530.
// Every other compilation pass (including those where __CUDA_ARCH__ is not
// defined, which the preprocessor treats as 0) widens to float, rounds with
// rintf, and converts the result back to half.
template <>
__device__ __inline__ half _Round(half a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return hrint(a);
#else
  const float widened = static_cast<float>(a);
  return half(rintf(widened));
#endif
}
// Primary template for the element-wise exponential helper; declared only.
// Per-type specializations are presumably defined further down (not visible here).
template <typename T>
__device__ __inline__ T _Exp(T a);

View file

@ -667,6 +667,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, bool, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int64_t, Equal);
// Version-11 Round kernel classes for the CUDA execution provider, one per
// element type: float, double and MLFloat16.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Round);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Round);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Round);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, CumSum);
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
@ -1129,6 +1132,9 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, bool, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int64_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Round)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Round)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Round)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, CumSum)>,
};

View file

@ -23,15 +23,14 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);
#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
// Registers kernel class x<T> for operator x in kOnnxDomain at version `ver`
// on the CUDA execution provider. The kernel definition constrains "T" to
// tensors of T and "T1" to tensors of bool.
#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()).TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
x<T>);
#define UNARY_ELEMENTWISE_COMPUTE(x, T) \
@ -51,7 +50,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
UNARY_ELEMENTWISE_REGISTER_KERNEL(name, ver, T) \
UNARY_ELEMENTWISE_COMPUTE(name, T)
#define UNARY_LOGICALOP_TYPED(name, ver, T) \
// Convenience wrapper for a typed logical unary op: expands to the logical-op
// kernel registration followed by UNARY_ELEMENTWISE_COMPUTE for the same
// (name, T) pair.
#define UNARY_LOGICALOP_TYPED(name, ver, T) \
UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \
UNARY_ELEMENTWISE_COMPUTE(name, T)
@ -98,6 +97,7 @@ UNARY_OP_HFD(Log, 6)
UNARY_OP_HFD(Exp, 6)
UNARY_OP_HFD(Erf, 9)
UNARY_LOGICALOP_TYPED(Not, 1, bool)
// Round is registered with version 11.
// NOTE(review): UNARY_OP_HFD presumably covers half/float/double — its
// definition is not visible here; confirm.
UNARY_OP_HFD(Round, 11)
} // namespace cuda
} // namespace onnxruntime

View file

@ -91,5 +91,12 @@ class Not final : public UnaryElementwise {
Status ComputeInternal(OpKernelContext* context) const override;
};
// Kernel class for the Round operator, parameterized on the element type T.
// Construction only forwards the kernel info to the UnaryElementwise base;
// ComputeInternal is declared here and defined out of line.
template <typename T>
class Round final : public UnaryElementwise {
 public:
  Round(const OpKernelInfo& kernel_info) : UnaryElementwise(kernel_info) {}

  Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -77,6 +77,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Log)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Exp)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
// NOTE(review): the _HFD suffix presumably means half/float/double; the macro
// definition is not visible here — confirm.
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
// When casting, half needs to be converted via float type from most other types

View file

@ -23,7 +23,8 @@ namespace cuda {
UNARY_OP_NAME_EXPR(Exp, _Exp(a)) \
UNARY_OP_NAME_EXPR(Log, _Log(a)) \
UNARY_OP_NAME_EXPR(Erf, _Erf(a)) \
UNARY_OP_NAME_EXPR(Not, !a)
UNARY_OP_NAME_EXPR(Not, !a) \
UNARY_OP_NAME_EXPR(Round, _Round(a))
#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \