Adding Min cuda kernel (#992)

This commit is contained in:
Du Li 2019-05-09 10:06:30 -07:00 committed by GitHub
parent 2de1f43a40
commit b08dbd37cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 82 additions and 1 deletions

View file

@ -329,6 +329,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, float, Max);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, double, Max);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, MLFloat16, Max);
// Forward declarations of the CUDA Min kernel classes in kOnnxDomain:
// the versioned 6-7 variants and the version 8 variants, each declared for
// float, double and MLFloat16 (the same set declared for Max).
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, float, Min);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, double, Min);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, MLFloat16, Min);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, float, Min);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, double, Min);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, MLFloat16, Min);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Greater);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Greater);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Greater);
@ -619,6 +625,12 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, float, Max)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, double, Max)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, MLFloat16, Max)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, float, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, double, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, MLFloat16, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, float, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, double, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, MLFloat16, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Greater)>,

View file

@ -303,6 +303,62 @@ Status Max<T>::ComputeInternal(OpKernelContext* context) const {
return Status::OK();
}
// Element-wise minimum across every input of this node, written to output 0.
//
// One input:  the input bytes are copied to the output unchanged.
// Two or more inputs:
//   1. the output shape is obtained by folding ComputeOutputShape over all
//      input shapes;
//   2. the output buffer is zero-filled and input 0 is added into it, so the
//      output starts out holding input 0 expanded to the full output shape;
//   3. Impl_Min is then applied between the output and each remaining input,
//      in input order, accumulating the result in the output buffer.
//
// Returns Status::OK() on success; otherwise the failure produced by the first
// ORT_RETURN_IF_* / CUDA_RETURN_IF_ERROR check that trips.
template <typename T>
Status Min<T>::ComputeInternal(OpKernelContext* context) const {
  using CudaT = typename ToCudaType<T>::MappedType;

  const auto& this_node = Node();
  const auto& node_name = this_node.Name();
  auto input_count = this_node.InputArgCount().front();
  ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs");

  // Single input: the minimum of one tensor is the tensor itself.
  if (input_count == 1) {
    auto input_tensor = context->Input<Tensor>(0);
    const auto& input_shape = input_tensor->Shape();
    auto output_tensor = context->Output(0, input_shape);
    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableDataRaw(), input_tensor->DataRaw(), sizeof(CudaT) * input_shape.Size(), cudaMemcpyDeviceToDevice));
    return Status::OK();
  }

  // Two or more inputs: combine the first two shapes, then fold in the rest.
  TensorShape output_shape;
  ORT_RETURN_IF_ERROR(ComputeOutputShape(node_name, context->Input<Tensor>(0)->Shape(), context->Input<Tensor>(1)->Shape(), output_shape));
  for (int index = 2; index < input_count; ++index) {
    // Snapshot the running shape so it is not both an input and the output of the call.
    TensorShape previous_output_shape = output_shape;
    ORT_RETURN_IF_ERROR(ComputeOutputShape(node_name, previous_output_shape, context->Input<Tensor>(index)->Shape(), output_shape));
  }

  Tensor* output_tensor = context->Output(0, output_shape);
  BinaryElementwisePreparation prepare(this);

  // Seed the output with input 0: zero the buffer, then compute output = output + input0
  // so that input 0 ends up expanded to the output shape.
  CUDA_RETURN_IF_ERROR(cudaMemset(output_tensor->MutableDataRaw(), 0, output_shape.Size() * sizeof(CudaT)));
  ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(0, output_tensor, context->Input<Tensor>(0), output_tensor, &prepare));
  {
    const auto* seed_lhs = reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>());
    const auto* seed_rhs = reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>());
    auto* seed_out = reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>());
    Impl_Add<CudaT>(
        prepare.output_rank_or_simple_broadcast,
        prepare.lhs_padded_strides.GpuPtr(),
        seed_lhs,
        prepare.rhs_padded_strides.GpuPtr(),
        seed_rhs,
        prepare.fdm_output_strides.GpuPtr(),
        prepare.fdm_H,
        prepare.fdm_C,
        seed_out,
        prepare.output_tensor->Shape().Size());
  }

  // Accumulate: output = min(output, input[index]) for every remaining input.
  for (int index = 1; index < input_count; ++index) {
    ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(0, output_tensor, context->Input<Tensor>(index), output_tensor, &prepare));
    const auto* running_min = reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>());
    const auto* next_input = reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>());
    auto* result = reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>());
    Impl_Min<CudaT>(
        prepare.output_rank_or_simple_broadcast,
        prepare.lhs_padded_strides.GpuPtr(),
        running_min,
        prepare.rhs_padded_strides.GpuPtr(),
        next_input,
        prepare.fdm_output_strides.GpuPtr(),
        prepare.fdm_H,
        prepare.fdm_C,
        result,
        prepare.output_tensor->Shape().Size());
  }

  return Status::OK();
}
//Greater op output tensor type is bool, so it cannot directly fit in the macros
//for other elementwise ops
template <typename T>
@ -347,6 +403,8 @@ BINARY_OP_REGISTER_UZILHFD(Greater, 9)
BINARY_OP_REGISTER_VERSIONED_HFD(Greater, 7, 8)
BINARY_OP_REGISTER_HFD(Max, 8)
BINARY_OP_REGISTER_VERSIONED_HFD(Max, 6, 7)
// Min registrations: version 8, plus the versioned range 6-7, using the same
// pair of registration macros as Max.
// NOTE(review): "HFD" presumably selects the MLFloat16/float/double variants — confirm against the macro definition.
BINARY_OP_REGISTER_HFD(Min, 8)
BINARY_OP_REGISTER_VERSIONED_HFD(Min, 6, 7)
} // namespace cuda
} // namespace onnxruntime

View file

@ -212,5 +212,14 @@ class Max final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
};
// CUDA kernel class for the Min operator, parameterised on the element type T.
// Construction only forwards the kernel info to CudaKernel; ComputeInternal is
// declared here and overrides the base-class entry point.
template <typename T>
class Min final : public CudaKernel {
 public:
  Min(const OpKernelInfo& info) : CudaKernel(info) {}

  // Runs the Min computation for the given kernel context.
  Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -83,6 +83,7 @@ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Xor, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(PRelu)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Greater)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Max)
// Min uses the same HFD specialization macro as Max.
// NOTE(review): presumably this instantiates Impl_Min for the half/float/double element types — confirm against the macro definition.
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Min)
} // namespace cuda
} // namespace onnxruntime

View file

@ -26,7 +26,8 @@ namespace cuda {
BINARY_OP_NAME_EXPR(Xor, (a ^ b)) \
BINARY_OP_NAME_EXPR(PRelu, (a > (T)0 ? a : a * b)) \
BINARY_OP_NAME_EXPR(Greater, (a > b) ? 1 : 0) \
BINARY_OP_NAME_EXPR(Max, _Max(a, b))
BINARY_OP_NAME_EXPR(Max, _Max(a, b)) \
BINARY_OP_NAME_EXPR(Min, _Min(a, b))
// NOTE that cu files are compiled with nvcc and should not refer to any onnxruntime headers
// so struct BinaryElementwisePreparation cannot be used here