diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 4a5faceb1e..6abbadc2d2 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -347,6 +347,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Greater); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater); @@ -675,6 +678,9 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc index c62a2b7a63..1aee5bff52 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc @@ -154,6 +154,11 @@ Status BinaryElementwise::Prepare(OpKernelContext* context, int BINARY_OP_TYPED(name, ver, int64_t) \ BINARY_OP_HFD(name, ver) +#define BINARY_OP_REGISTER_OIL(name, ver) \ + BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, bool) \ + BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int32_t) \ + BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int64_t) + #define BINARY_OP_REGISTER_HFD(name, ver) \ BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, MLFloat16) \ BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, float) \ @@ -397,9 +402,46 @@ Status Greater::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } +template +Status Equal::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaT; + const onnxruntime::Node& node = OpKernel::Node(); + const std::string& name = node.Name(); + + const Tensor* input0 = context->Input(0); + const Tensor* input1 = context->Input(1); + TensorShape output_shape; + ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape)); + size_t output_size = output_shape.Size(); + Tensor* output_tensor = context->Output(0, output_shape); + + BinaryElementwisePreparation prepare(this); + ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(0, input0, input1, output_tensor, &prepare)); + + IAllocatorUniquePtr output_buffer = GetScratchBuffer(output_size); + Impl_Equal( + prepare.output_rank_or_simple_broadcast, + prepare.lhs_padded_strides.GpuPtr(), + reinterpret_cast(prepare.lhs_tensor->template Data()), + prepare.rhs_padded_strides.GpuPtr(), + reinterpret_cast(prepare.rhs_tensor->template Data()), + prepare.fdm_output_strides.GpuPtr(), + prepare.fdm_H, + prepare.fdm_C, + reinterpret_cast(output_buffer.get()), + output_size); + + Impl_Cast::MappedType>( + reinterpret_cast(output_buffer.get()), + reinterpret_cast::MappedType*>(output_tensor->template MutableData()), + output_size); + return Status::OK(); +} + BINARY_OP_REGISTER_UZILHFD(Sum, 8) BINARY_OP_REGISTER_VERSIONED_UZILHFD(Sum, 6, 7) BINARY_OP_REGISTER_UZILHFD(Greater, 9) +BINARY_OP_REGISTER_OIL(Equal, 7) BINARY_OP_REGISTER_VERSIONED_HFD(Greater, 7, 8) BINARY_OP_REGISTER_HFD(Max, 8) BINARY_OP_REGISTER_VERSIONED_HFD(Max, 6, 7) diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h index 746640b1d9..2489edf3e8 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h @@ -204,6 +204,15 @@ class Greater final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; }; +template +class Equal final : public CudaKernel { + public: + Equal(const OpKernelInfo& info) : CudaKernel(info) {} + + Status ComputeInternal(OpKernelContext* context) const override; +}; + + template class Max final : public CudaKernel { public: diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu index d025fb0340..3260fdcf42 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu @@ -44,6 +44,11 @@ namespace cuda { SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) +#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_OIL(x) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, bool) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \ + SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t) + #define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \ @@ -82,6 +87,7 @@ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Or, bool) SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Xor, bool) SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(PRelu) SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Greater) +SPECIALIZED_BINARY_ELEMENTWISE_IMPL_OIL(Equal) SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Max) SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Min) diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.h index 2220ec7a2f..cacc12f859 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.h @@ -26,6 +26,7 @@ namespace cuda { BINARY_OP_NAME_EXPR(Xor, (a ^ b)) \ BINARY_OP_NAME_EXPR(PRelu, (a > (T)0 ? a : a * b)) \ BINARY_OP_NAME_EXPR(Greater, (a > b) ? 1 : 0) \ + BINARY_OP_NAME_EXPR(Equal, ((a == b) ? 1 : 0)) \ BINARY_OP_NAME_EXPR(Max, _Max(a, b)) \ BINARY_OP_NAME_EXPR(Min, _Min(a, b))