diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 7db4629708..47eb9cd10e 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -372,6 +372,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Erf); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, bool, Not); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization); @@ -709,6 +710,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 323a7ad997..af58f32c9c 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -23,6 +23,17 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ x); +#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + x, \ + kOnnxDomain, \ + ver, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ + x); + #define UNARY_ELEMENTWISE_COMPUTE(x, T) \ template <> \ Status x::ComputeInternal(OpKernelContext* context) const { \ @@ -40,6 +51,10 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa UNARY_ELEMENTWISE_REGISTER_KERNEL(name, ver, T) \ UNARY_ELEMENTWISE_COMPUTE(name, T) +#define UNARY_LOGICALOP_TYPED(name, ver, T) \ + UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \ + UNARY_ELEMENTWISE_COMPUTE(name, T) + // the postfix of means the types supported by the op: // B: uint8_t // W: uint16_t @@ -82,6 +97,7 @@ UNARY_OP_HFD(Sqrt, 6) UNARY_OP_HFD(Log, 6) UNARY_OP_HFD(Exp, 6) UNARY_OP_HFD(Erf, 9) +UNARY_LOGICALOP_TYPED(Not, 1, bool) } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 8ae3753ba4..1bd6d2160c 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -84,5 +84,12 @@ class Erf final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +template +class Not final : public UnaryElementwise { + public: + Not(const OpKernelInfo& info) : UnaryElementwise(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 4049cc1e90..d1a006227a 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -77,6 +77,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Log) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Exp) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool) // When casting, half needs to be converted via float type from most other types template diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index daaa3a5304..c22f597507 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -22,7 +22,8 @@ namespace cuda { UNARY_OP_NAME_EXPR(Sqrt, _Sqrt(a)) \ UNARY_OP_NAME_EXPR(Exp, _Exp(a)) \ UNARY_OP_NAME_EXPR(Log, _Log(a)) \ - UNARY_OP_NAME_EXPR(Erf, _Erf(a)) + UNARY_OP_NAME_EXPR(Erf, _Erf(a)) \ + UNARY_OP_NAME_EXPR(Not, !a) #define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \ template \