Add Cuda Kernel for Not operator (#1801)

* Add Cuda Kernel for Not operator
* Register Not CUDA Kernel
This commit is contained in:
Chi Lo 2019-09-11 14:30:44 -07:00 committed by GitHub
parent a9e4de2cea
commit d9fa632863
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 28 additions and 1 deletions

View file

@ -372,6 +372,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, bool, Not);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization);
@ -709,6 +710,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, bool, Not)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization)>,

View file

@ -23,6 +23,17 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);
// Kernel registration for logical unary ops: expands to
// ONNX_OPERATOR_TYPED_KERNEL_EX for operator `x` in kOnnxDomain at opset
// version `ver`, typed on `T`, for the CUDA execution provider, bound to the
// kernel class x<T>.
// The kernel definition constrains "T" to tensors of T and additionally
// constrains "T1" to bool tensors.
// NOTE(review): whether every op registered through this macro actually
// declares a "T1" type constraint in its schema is not visible here - confirm.
#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
x<T>);
#define UNARY_ELEMENTWISE_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
@ -40,6 +51,10 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
UNARY_ELEMENTWISE_REGISTER_KERNEL(name, ver, T) \
UNARY_ELEMENTWISE_COMPUTE(name, T)
// Convenience wrapper for a logical unary op `name` at opset `ver` and type T:
// emits the kernel registration
// (UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED) followed by the
// name<T>::ComputeInternal specialization (UNARY_ELEMENTWISE_COMPUTE).
#define UNARY_LOGICALOP_TYPED(name, ver, T) \
UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \
UNARY_ELEMENTWISE_COMPUTE(name, T)
// the postfix indicates the types supported by the op:
// B: uint8_t
// W: uint16_t
@ -82,6 +97,7 @@ UNARY_OP_HFD(Sqrt, 6)
UNARY_OP_HFD(Log, 6)
UNARY_OP_HFD(Exp, 6)
UNARY_OP_HFD(Erf, 9)
// Not (opset 1) is instantiated for bool only, through the logical-op macro.
UNARY_LOGICALOP_TYPED(Not, 1, bool)
} // namespace cuda
} // namespace onnxruntime

View file

@ -84,5 +84,12 @@ class Erf final : public UnaryElementwise {
Status ComputeInternal(OpKernelContext* context) const override;
};
// CUDA kernel class for the Not operator, templated on the element type T.
// Derives from UnaryElementwise and forwards the kernel info to it.
// Only the ComputeInternal override is declared here; its definition is
// provided outside this class body.
template <typename T>
class Not final : public UnaryElementwise {
public:
// Passes `info` straight through to the UnaryElementwise base.
Not(const OpKernelInfo& info) : UnaryElementwise(info) {}
// Executes the kernel for the given context; returns a Status.
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -77,6 +77,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Log)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Exp)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
// When casting, half needs to be converted via float type from most other types
template <typename T>

View file

@ -22,7 +22,8 @@ namespace cuda {
UNARY_OP_NAME_EXPR(Sqrt, _Sqrt(a)) \
UNARY_OP_NAME_EXPR(Exp, _Exp(a)) \
UNARY_OP_NAME_EXPR(Log, _Log(a)) \
UNARY_OP_NAME_EXPR(Erf, _Erf(a))
UNARY_OP_NAME_EXPR(Erf, _Erf(a)) \
UNARY_OP_NAME_EXPR(Not, !a)
#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \