mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Adding Min cuda kernel (#992)
This commit is contained in:
parent
2de1f43a40
commit
b08dbd37cb
5 changed files with 82 additions and 1 deletions
|
|
@ -329,6 +329,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, float, Max);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, double, Max);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, MLFloat16, Max);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, float, Min);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, double, Min);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, MLFloat16, Min);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, float, Min);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, double, Min);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, MLFloat16, Min);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Greater);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Greater);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Greater);
|
||||
|
|
@ -619,6 +625,12 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, float, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, double, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, MLFloat16, Max)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, float, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, double, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 7, MLFloat16, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, float, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, double, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, MLFloat16, Min)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Greater)>,
|
||||
|
|
|
|||
|
|
@ -303,6 +303,62 @@ Status Max<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Min<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const auto& node = Node();
|
||||
const auto& node_name = node.Name();
|
||||
auto input_count = node.InputArgCount().front();
|
||||
ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs");
|
||||
|
||||
if (input_count == 1) {
|
||||
auto input_tensor = context->Input<Tensor>(0);
|
||||
const auto& input_shape = input_tensor->Shape();
|
||||
auto output_tensor = context->Output(0, input_shape);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableDataRaw(), input_tensor->DataRaw(), sizeof(CudaT) * input_shape.Size(), cudaMemcpyDeviceToDevice));
|
||||
} else {
|
||||
// compute output shape first, using broadcast rule
|
||||
TensorShape output_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(node_name, context->Input<Tensor>(0)->Shape(), context->Input<Tensor>(1)->Shape(), output_shape));
|
||||
for (int index = 2; index < input_count; index++) {
|
||||
TensorShape previous_output_shape = output_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(node_name, previous_output_shape, context->Input<Tensor>(index)->Shape(), output_shape));
|
||||
}
|
||||
Tensor* output_tensor = context->Output(0, output_shape);
|
||||
BinaryElementwisePreparation prepare(this);
|
||||
|
||||
// More than 2 inputs, set output to 0, add input0 to output, so that input0 can be broadcast with output shape correctly
|
||||
CUDA_RETURN_IF_ERROR(cudaMemset(output_tensor->MutableDataRaw(), 0, output_shape.Size() * sizeof(CudaT)));
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(0, output_tensor, context->Input<Tensor>(0), output_tensor, &prepare));
|
||||
Impl_Add<CudaT>(
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
prepare.lhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
|
||||
prepare.rhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
|
||||
prepare.fdm_output_strides.GpuPtr(),
|
||||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
|
||||
prepare.output_tensor->Shape().Size());
|
||||
for (int index = 1; index < input_count; index++) {
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(0, output_tensor, context->Input<Tensor>(index), output_tensor, &prepare));
|
||||
Impl_Min<CudaT>(
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
prepare.lhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
|
||||
prepare.rhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
|
||||
prepare.fdm_output_strides.GpuPtr(),
|
||||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
|
||||
prepare.output_tensor->Shape().Size());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//Greater op output tensor type is bool, so it cannot directly fit in the macros
|
||||
//for other elementwise ops
|
||||
template <typename T>
|
||||
|
|
@ -347,6 +403,8 @@ BINARY_OP_REGISTER_UZILHFD(Greater, 9)
|
|||
BINARY_OP_REGISTER_VERSIONED_HFD(Greater, 7, 8)
|
||||
BINARY_OP_REGISTER_HFD(Max, 8)
|
||||
BINARY_OP_REGISTER_VERSIONED_HFD(Max, 6, 7)
|
||||
BINARY_OP_REGISTER_HFD(Min, 8)
|
||||
BINARY_OP_REGISTER_VERSIONED_HFD(Min, 6, 7)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -212,5 +212,14 @@ class Max final : public CudaKernel {
|
|||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Min final : public CudaKernel {
|
||||
public:
|
||||
Min(const OpKernelInfo& info) : CudaKernel(info) {
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Xor, bool)
|
|||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(PRelu)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Greater)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Max)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Min)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ namespace cuda {
|
|||
BINARY_OP_NAME_EXPR(Xor, (a ^ b)) \
|
||||
BINARY_OP_NAME_EXPR(PRelu, (a > (T)0 ? a : a * b)) \
|
||||
BINARY_OP_NAME_EXPR(Greater, (a > b) ? 1 : 0) \
|
||||
BINARY_OP_NAME_EXPR(Max, _Max(a, b))
|
||||
BINARY_OP_NAME_EXPR(Max, _Max(a, b)) \
|
||||
BINARY_OP_NAME_EXPR(Min, _Min(a, b))
|
||||
|
||||
// NOTE that cu files are compiled with nvcc and should not refer to any onnxruntime headers
|
||||
// so struct BinaryElementwisePreparation cannot be used here
|
||||
|
|
|
|||
Loading…
Reference in a new issue