Implement Equal for CUDA. (#1183)

This commit is contained in:
Dmitri Smirnov 2019-06-07 11:11:50 -07:00 committed by GitHub
parent d33dbb23b2
commit e43e64bf84
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 64 additions and 0 deletions

View file

@ -347,6 +347,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Greater);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Greater);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Greater);
// Forward declarations of the CUDA-provider Equal kernel classes in the ONNX
// domain, version argument 7, one per supported element type: bool, int32_t
// and int64_t.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater);
@ -675,6 +678,9 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater)>,

View file

@ -154,6 +154,11 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, int
BINARY_OP_TYPED(name, ver, int64_t) \
BINARY_OP_HFD(name, ver)
// Registers the binary element-wise kernel `name` at version `ver` for three
// element types: bool, int32_t and int64_t. Each line forwards to
// BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED with one of those types.
// NOTE(review): the "OIL" suffix presumably abbreviates the type list
// (bOol / Int32 / int64 "Long"), following the letter scheme of the sibling
// *_HFD / *_UZILHFD macros - confirm against their definitions.
#define BINARY_OP_REGISTER_OIL(name, ver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, bool) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int64_t)
#define BINARY_OP_REGISTER_HFD(name, ver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, MLFloat16) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, float) \
@ -397,9 +402,46 @@ Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
return Status::OK();
}
// Computes element-wise equality of inputs 0 and 1 and writes the result to
// output 0 as a bool tensor.
//
// Steps:
//   1. Derive the output shape from the two input shapes.
//   2. Fill a BinaryElementwisePreparation for the two inputs and the output.
//   3. Run Impl_Equal with a scratch buffer of element type T as destination.
//   4. Run Impl_Cast from that scratch buffer into the bool output tensor.
//
// Returns the failing Status from shape computation or preparation, otherwise
// Status::OK().
//
// NOTE(review): when T is bool the scratch buffer and the cast step look
// redundant (bool -> bool) - confirm before simplifying.
template <typename T>
Status Equal<T>::ComputeInternal(OpKernelContext* context) const {
  using CudaT = typename ToCudaType<T>::MappedType;
  using CudaBool = ToCudaType<bool>::MappedType;

  // The node name is handed to ComputeOutputShape along with both shapes.
  const std::string& node_name = OpKernel::Node().Name();

  const Tensor* lhs = context->Input<Tensor>(0);
  const Tensor* rhs = context->Input<Tensor>(1);

  TensorShape result_shape;
  ORT_RETURN_IF_ERROR(ComputeOutputShape(node_name, lhs->Shape(), rhs->Shape(), result_shape));
  const size_t element_count = result_shape.Size();

  Tensor* result = context->Output(0, result_shape);

  BinaryElementwisePreparation prep(this);
  ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(0, lhs, rhs, result, &prep));

  // Intermediate buffer holding one T per output element.
  IAllocatorUniquePtr<T> scratch = GetScratchBuffer<T>(element_count);
  CudaT* scratch_data = reinterpret_cast<CudaT*>(scratch.get());

  const CudaT* lhs_data = reinterpret_cast<const CudaT*>(prep.lhs_tensor->template Data<T>());
  const CudaT* rhs_data = reinterpret_cast<const CudaT*>(prep.rhs_tensor->template Data<T>());

  Impl_Equal<CudaT>(
      prep.output_rank_or_simple_broadcast,
      prep.lhs_padded_strides.GpuPtr(),
      lhs_data,
      prep.rhs_padded_strides.GpuPtr(),
      rhs_data,
      prep.fdm_output_strides.GpuPtr(),
      prep.fdm_H,
      prep.fdm_C,
      scratch_data,
      element_count);

  // Convert the T-typed intermediate values into the bool output tensor.
  CudaBool* result_data = reinterpret_cast<CudaBool*>(result->template MutableData<bool>());
  Impl_Cast<CudaT, CudaBool>(scratch_data, result_data, element_count);

  return Status::OK();
}
BINARY_OP_REGISTER_UZILHFD(Sum, 8)
BINARY_OP_REGISTER_VERSIONED_UZILHFD(Sum, 6, 7)
BINARY_OP_REGISTER_UZILHFD(Greater, 9)
// Equal at version 7 is registered for bool, int32_t and int64_t (the type
// list of BINARY_OP_REGISTER_OIL).
BINARY_OP_REGISTER_OIL(Equal, 7)
BINARY_OP_REGISTER_VERSIONED_HFD(Greater, 7, 8)
BINARY_OP_REGISTER_HFD(Max, 8)
BINARY_OP_REGISTER_VERSIONED_HFD(Max, 6, 7)

View file

@ -204,6 +204,15 @@ class Greater final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
};
// CUDA kernel class for the Equal operator, templated on the input element
// type T. Construction only forwards the kernel info to CudaKernel; the work
// is done in ComputeInternal, which is declared here and defined out of line.
template <typename T>
class Equal final : public CudaKernel {
 public:
  Equal(const OpKernelInfo& info)
      : CudaKernel(info) {
  }

  // Overrides the CudaKernel compute hook for this operator.
  Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Max final : public CudaKernel {
public:

View file

@ -44,6 +44,11 @@ namespace cuda {
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double)
// Expands SPECIALIZED_BINARY_ELEMENTWISE_IMPL for operation `x` once per
// element type in the list bool, int32_t, int64_t.
// NOTE(review): "OIL" presumably abbreviates that type list
// (bOol / Int32 / int64 "Long") - confirm against the sibling *_HFD macros.
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_OIL(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, bool) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t)
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
@ -82,6 +87,7 @@ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Or, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Xor, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(PRelu)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Greater)
// Equal is specialized for bool, int32_t and int64_t (the _OIL type list).
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_OIL(Equal)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Max)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Min)

View file

@ -26,6 +26,7 @@ namespace cuda {
BINARY_OP_NAME_EXPR(Xor, (a ^ b)) \
BINARY_OP_NAME_EXPR(PRelu, (a > (T)0 ? a : a * b)) \
BINARY_OP_NAME_EXPR(Greater, (a > b) ? 1 : 0) \
BINARY_OP_NAME_EXPR(Equal, ((a == b) ? 1 : 0)) \
BINARY_OP_NAME_EXPR(Max, _Max(a, b)) \
BINARY_OP_NAME_EXPR(Min, _Min(a, b))