mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-16 01:33:39 +00:00
Cuda Clip() for op set 11. (#2411)
* Cuda Clip() for op set 11. * make min_val and max_value input CPU memory directly. * Remove original cu file useless "#pragma once" * merge duplicate logic into one class.
This commit is contained in:
parent
ccbd778d0d
commit
04b6097db4
4 changed files with 60 additions and 28 deletions
|
|
@ -363,7 +363,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Ceil);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Ceil);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Ceil);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Clip);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 10, float, Clip);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Reciprocal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Reciprocal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Reciprocal);
|
||||
|
|
@ -652,6 +652,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Clip);
|
||||
|
||||
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -676,7 +677,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Clip)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 10, float, Clip)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Tile)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Tile)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Tile)>,
|
||||
|
|
@ -1099,6 +1100,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Clip)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -8,38 +8,61 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Clip, \
|
||||
kOnnxDomain, \
|
||||
6, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
Clip, \
|
||||
kOnnxDomain, \
|
||||
6, \
|
||||
10, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Clip<T>); \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Clip, \
|
||||
kOnnxDomain, \
|
||||
11, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(1) \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(2) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Clip<T>);
|
||||
|
||||
template <typename T>
|
||||
Status Clip<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
T min_val = min_;
|
||||
T max_val = max_;
|
||||
if (is_min_max_input_) {
|
||||
const auto* min_input = ctx->Input<Tensor>(1);
|
||||
const auto* max_input = ctx->Input<Tensor>(2);
|
||||
if (min_input) {
|
||||
ORT_ENFORCE(min_input->Shape().NumDimensions() == 0, "min should be a scalar.");
|
||||
min_val = *(min_input->template Data<T>());
|
||||
}
|
||||
if (max_input) {
|
||||
ORT_ENFORCE(max_input->Shape().NumDimensions() == 0, "max should be a scalar.");
|
||||
max_val = *(max_input->template Data<T>());
|
||||
}
|
||||
ORT_ENFORCE(min_val <= max_val);
|
||||
}
|
||||
|
||||
const Tensor& X = *ctx->Input<Tensor>(0);
|
||||
const TensorShape input_shape{X.Shape()};
|
||||
const TensorShape& input_shape{X.Shape()};
|
||||
Tensor* Y = ctx->Output(0, input_shape);
|
||||
|
||||
size_t count = input_shape.Size();
|
||||
|
||||
if (count > 0) {
|
||||
auto* y_data = Y->template MutableData<T>();
|
||||
const auto* x_data = X.template Data<T>();
|
||||
ClipImpl<T>(x_data, y_data, min_, max_, count);
|
||||
ClipImpl<T>(x_data, y_data, min_val, max_val, count);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define SPECIALIZED_COMPUTE(T) \
|
||||
REGISTER_KERNEL_TYPED(T) \
|
||||
template Status Clip<T>::ComputeInternal(OpKernelContext* ctx) const;
|
||||
|
||||
SPECIALIZED_COMPUTE(float)
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,21 +10,29 @@ namespace cuda {
|
|||
template <typename T>
|
||||
class Clip final : public CudaKernel {
|
||||
public:
|
||||
Clip(const OpKernelInfo& info) : CudaKernel{info} {
|
||||
auto min_val = -std::numeric_limits<T>::infinity();
|
||||
auto max_val = std::numeric_limits<T>::infinity();
|
||||
Clip(const OpKernelInfo& info) : CudaKernel{info}, is_min_max_input_(false) {
|
||||
int start_version;
|
||||
int end_version;
|
||||
info.GetKernelDef().SinceVersion(&start_version, &end_version);
|
||||
|
||||
info.GetAttrOrDefault("min", &min_, min_val);
|
||||
info.GetAttrOrDefault("max", &max_, max_val);
|
||||
|
||||
// Make sure the range of interval is sensible
|
||||
ORT_ENFORCE(min_val <= max_val);
|
||||
if (start_version < 11) {
|
||||
auto min_val = -std::numeric_limits<T>::infinity();
|
||||
auto max_val = std::numeric_limits<T>::infinity();
|
||||
info.GetAttrOrDefault("min", &min_, min_val);
|
||||
info.GetAttrOrDefault("max", &max_, max_val);
|
||||
ORT_ENFORCE(min_ <= max_);
|
||||
} else {
|
||||
min_ = -std::numeric_limits<T>::infinity();
|
||||
max_ = std::numeric_limits<T>::infinity();
|
||||
is_min_max_input_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
T min_, max_;
|
||||
bool is_min_max_input_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/providers/cuda/math/clip_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue