mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Implement Equal for CUDA. (#1183)
This commit is contained in:
parent
d33dbb23b2
commit
e43e64bf84
5 changed files with 64 additions and 0 deletions
|
|
@ -347,6 +347,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Greater);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Greater);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Greater);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater);
|
||||
|
|
@ -675,6 +678,9 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater)>,
|
||||
|
|
|
|||
|
|
@ -154,6 +154,11 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, int
|
|||
BINARY_OP_TYPED(name, ver, int64_t) \
|
||||
BINARY_OP_HFD(name, ver)
|
||||
|
||||
// Expands to one BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, T)
// invocation for each of the element types bool, int32_t and int64_t.
// NOTE(review): the "OIL" suffix presumably abbreviates those three types,
// following the letter scheme of the sibling *_HFD / *_UZILHFD macros - confirm.
#define BINARY_OP_REGISTER_OIL(name, ver)                      \
  BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, bool)    \
  BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int32_t) \
  BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int64_t)
|
||||
|
||||
#define BINARY_OP_REGISTER_HFD(name, ver) \
|
||||
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, MLFloat16) \
|
||||
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, float) \
|
||||
|
|
@ -397,9 +402,46 @@ Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Equal<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const onnxruntime::Node& node = OpKernel::Node();
|
||||
const std::string& name = node.Name();
|
||||
|
||||
const Tensor* input0 = context->Input<Tensor>(0);
|
||||
const Tensor* input1 = context->Input<Tensor>(1);
|
||||
TensorShape output_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape));
|
||||
size_t output_size = output_shape.Size();
|
||||
Tensor* output_tensor = context->Output(0, output_shape);
|
||||
|
||||
BinaryElementwisePreparation prepare(this);
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(0, input0, input1, output_tensor, &prepare));
|
||||
|
||||
IAllocatorUniquePtr<T> output_buffer = GetScratchBuffer<T>(output_size);
|
||||
Impl_Equal<CudaT>(
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
prepare.lhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
|
||||
prepare.rhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
|
||||
prepare.fdm_output_strides.GpuPtr(),
|
||||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<CudaT*>(output_buffer.get()),
|
||||
output_size);
|
||||
|
||||
Impl_Cast<CudaT, ToCudaType<bool>::MappedType>(
|
||||
reinterpret_cast<CudaT*>(output_buffer.get()),
|
||||
reinterpret_cast<ToCudaType<bool>::MappedType*>(output_tensor->template MutableData<bool>()),
|
||||
output_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
BINARY_OP_REGISTER_UZILHFD(Sum, 8)
|
||||
BINARY_OP_REGISTER_VERSIONED_UZILHFD(Sum, 6, 7)
|
||||
BINARY_OP_REGISTER_UZILHFD(Greater, 9)
|
||||
BINARY_OP_REGISTER_OIL(Equal, 7)
|
||||
BINARY_OP_REGISTER_VERSIONED_HFD(Greater, 7, 8)
|
||||
BINARY_OP_REGISTER_HFD(Max, 8)
|
||||
BINARY_OP_REGISTER_VERSIONED_HFD(Max, 6, 7)
|
||||
|
|
|
|||
|
|
@ -204,6 +204,15 @@ class Greater final : public CudaKernel {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Equal final : public CudaKernel {
|
||||
public:
|
||||
Equal(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
class Max final : public CudaKernel {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -44,6 +44,11 @@ namespace cuda {
|
|||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double)
|
||||
|
||||
// Expands to one SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, T) invocation for
// each of the element types bool, int32_t and int64_t.
// NOTE(review): the "OIL" suffix presumably abbreviates those three types,
// mirroring the neighbouring *_HFD / *_UZILHFD macros - confirm.
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_OIL(x)  \
  SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, bool)      \
  SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t)   \
  SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t)
|
||||
|
||||
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
|
||||
|
|
@ -82,6 +87,7 @@ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Or, bool)
|
|||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Xor, bool)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(PRelu)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Greater)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_OIL(Equal)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Max)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Min)
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ namespace cuda {
|
|||
BINARY_OP_NAME_EXPR(Xor, (a ^ b)) \
|
||||
BINARY_OP_NAME_EXPR(PRelu, (a > (T)0 ? a : a * b)) \
|
||||
BINARY_OP_NAME_EXPR(Greater, (a > b) ? 1 : 0) \
|
||||
BINARY_OP_NAME_EXPR(Equal, ((a == b) ? 1 : 0)) \
|
||||
BINARY_OP_NAME_EXPR(Max, _Max(a, b)) \
|
||||
BINARY_OP_NAME_EXPR(Min, _Min(a, b))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue