mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Add Cuda Kernel for Not operator (#1801)
* Add Cuda Kernel for Not operator * Register Not CUDA Kernel
This commit is contained in:
parent
a9e4de2cea
commit
d9fa632863
5 changed files with 28 additions and 1 deletions
|
|
@ -372,6 +372,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Erf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Erf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Erf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, bool, Not);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization);
|
||||
|
|
@ -709,6 +710,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Erf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, bool, Not)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization)>,
|
||||
|
|
|
|||
|
|
@ -23,6 +23,17 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
x<T>);
|
||||
|
||||
#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
x, \
|
||||
kOnnxDomain, \
|
||||
ver, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
|
||||
x<T>);
|
||||
|
||||
#define UNARY_ELEMENTWISE_COMPUTE(x, T) \
|
||||
template <> \
|
||||
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
|
||||
|
|
@ -40,6 +51,10 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
|
|||
UNARY_ELEMENTWISE_REGISTER_KERNEL(name, ver, T) \
|
||||
UNARY_ELEMENTWISE_COMPUTE(name, T)
|
||||
|
||||
#define UNARY_LOGICALOP_TYPED(name, ver, T) \
|
||||
UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \
|
||||
UNARY_ELEMENTWISE_COMPUTE(name, T)
|
||||
|
||||
// the postfix of means the types supported by the op:
|
||||
// B: uint8_t
|
||||
// W: uint16_t
|
||||
|
|
@ -82,6 +97,7 @@ UNARY_OP_HFD(Sqrt, 6)
|
|||
UNARY_OP_HFD(Log, 6)
|
||||
UNARY_OP_HFD(Exp, 6)
|
||||
UNARY_OP_HFD(Erf, 9)
|
||||
UNARY_LOGICALOP_TYPED(Not, 1, bool)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -84,5 +84,12 @@ class Erf final : public UnaryElementwise {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Not final : public UnaryElementwise {
|
||||
public:
|
||||
Not(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt)
|
|||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Log)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Exp)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
|
||||
|
||||
// When casting, half needs to be converted via float type from most other types
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ namespace cuda {
|
|||
UNARY_OP_NAME_EXPR(Sqrt, _Sqrt(a)) \
|
||||
UNARY_OP_NAME_EXPR(Exp, _Exp(a)) \
|
||||
UNARY_OP_NAME_EXPR(Log, _Log(a)) \
|
||||
UNARY_OP_NAME_EXPR(Erf, _Erf(a))
|
||||
UNARY_OP_NAME_EXPR(Erf, _Erf(a)) \
|
||||
UNARY_OP_NAME_EXPR(Not, !a)
|
||||
|
||||
#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
|
||||
template <typename T> \
|
||||
|
|
|
|||
Loading…
Reference in a new issue