From ac5b5e5d1e4f27cd7fa541cc1ca6d16e36db4949 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Tue, 12 Jan 2021 10:46:21 +0800 Subject: [PATCH] more dtype for Equal CUDA kernel (#6288) Co-authored-by: Vincent Wang --- .../providers/cuda/cuda_execution_provider.cc | 20 +++++++++++++++++++ .../cuda/math/binary_elementwise_ops.cc | 11 ++++------ .../cuda/math/binary_elementwise_ops_impl.cu | 5 +---- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index c8cdf274f4..e1a6b339a6 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -746,6 +746,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, bool, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, uint32_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, uint64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Round); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Round); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Round); @@ -868,6 +873,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, M class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Greater); @@ -1455,6 +1465,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1573,6 +1588,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc index b389c49cd4..f8336dd4db 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc @@ -212,11 +212,6 @@ Status BinaryElementwise::Prepare(OpKernelContext* context, Bin BINARY_OP_TYPED(name, ver, int64_t) \ BINARY_OP_HFD(name, ver) -#define BINARY_OP_REGISTER_OIL(name, ver) \ - BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, bool) \ - BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int32_t) \ - BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int64_t) - #define BINARY_OP_REGISTER_VERSIONED_OIL(name, startver, endver) \ BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, bool) \ BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int32_t) \ @@ -456,8 +451,10 @@ Status Less::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -BINARY_OP_REGISTER_OIL(Equal, 13) -BINARY_OP_REGISTER_VERSIONED_OIL(Equal, 11, 12) +BINARY_LOGICALOP_REGISTER_UZILHFD(Equal, 13) +BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(Equal, 13, bool) +BINARY_OP_REGISTER_VERSIONED_UZILHFD(Equal, 11, 12) +BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(Equal, 11, 12, bool) BINARY_OP_REGISTER_VERSIONED_OIL(Equal, 7, 10) BINARY_LOGICALOP_REGISTER_UZILHFD(Greater, 13) BINARY_OP_REGISTER_VERSIONED_UZILHFD(Greater, 9, 12) diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu index e4cab89128..6fc6d7dd1f 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu @@ -160,11 +160,8 @@ BINARY_OPS2() SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, double, double) SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(Greater) - +SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(Equal) SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(Equal, bool, bool, bool) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(Equal, bool, int32_t, int32_t) -SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(Equal, bool, int64_t, int64_t) - SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(Less) } // namespace cuda