mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Implement Range Cuda Kernel to improve performance (#2148)
This commit is contained in:
parent
7efc9bdcc7
commit
18b192a45b
5 changed files with 189 additions and 0 deletions
|
|
@ -546,6 +546,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, Scatter);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, ScatterElements);
|
||||
|
|
@ -894,6 +895,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, Scatter)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, ScatterElements)>,
|
||||
|
|
|
|||
107
onnxruntime/core/providers/cuda/generator/range.cc
Normal file
107
onnxruntime/core/providers/cuda/generator/range.cc
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "range.h"
|
||||
#include "range_impl.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Range,
|
||||
kOnnxDomain,
|
||||
11,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<int16_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Range);
|
||||
|
||||
template <typename T>
|
||||
static Status ComputeRange(OpKernelContext* ctx) {
|
||||
const auto& start_tensor = *ctx->Input<Tensor>(0);
|
||||
const auto& limit_tensor = *ctx->Input<Tensor>(1);
|
||||
const auto* delta_tensor_ptr = ctx->Input<Tensor>(2);
|
||||
|
||||
if (!start_tensor.Shape().IsScalar()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"start in Range operator should be scalar like tensor, yet got shape:",
|
||||
start_tensor.Shape());
|
||||
}
|
||||
if (!limit_tensor.Shape().IsScalar()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"limit in Range operator should be scalar like tensor, yet got shape:",
|
||||
limit_tensor.Shape());
|
||||
}
|
||||
if (delta_tensor_ptr != nullptr && !delta_tensor_ptr->Shape().IsScalar()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"delta in Range operator should be scalar like tensor, yet got shape:",
|
||||
delta_tensor_ptr->Shape());
|
||||
}
|
||||
|
||||
// Start, Limit and Delta are stored in GPU. So we need copy it to CPU to read.
|
||||
// It is better to store these tensors in pinned memory or CPU for better performance.
|
||||
T start;
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(&start, start_tensor.template Data<T>(), sizeof(T), cudaMemcpyDeviceToHost));
|
||||
|
||||
T limit;
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(&limit, limit_tensor.template Data<T>(), sizeof(T), cudaMemcpyDeviceToHost));
|
||||
|
||||
T delta = T(1);
|
||||
if (delta_tensor_ptr != nullptr) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(&delta, delta_tensor_ptr->template Data<T>(), sizeof(T), cudaMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
if (delta == T(0)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "delta in Range operator can not be zero!");
|
||||
}
|
||||
|
||||
int count = static_cast<int>(ceil(1.0 * (limit - start) / delta));
|
||||
if (count <= 0)
|
||||
count = 0;
|
||||
TensorShape shape = {static_cast<int64_t>(count)};
|
||||
T* y = ctx->Output(0, shape)->template MutableData<T>();
|
||||
|
||||
if (count > 0) {
|
||||
if (!RangeImpl(start, delta, count, y)) {
|
||||
CUDA_CALL(cudaGetLastError());
|
||||
return Status(common::ONNXRUNTIME, common::FAIL);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Range::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const auto* input_tensor = ctx->Input<Tensor>(0);
|
||||
if (input_tensor == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
}
|
||||
|
||||
auto data_type = input_tensor->DataType();
|
||||
if (data_type == DataTypeImpl::GetType<int32_t>()) {
|
||||
return ComputeRange<int32_t>(ctx);
|
||||
} else if (data_type == DataTypeImpl::GetType<float>()) {
|
||||
return ComputeRange<float>(ctx);
|
||||
} else if (data_type == DataTypeImpl::GetType<int64_t>()) {
|
||||
return ComputeRange<int64_t>(ctx);
|
||||
} else if (data_type == DataTypeImpl::GetType<double>()) {
|
||||
return ComputeRange<double>(ctx);
|
||||
} else if (data_type == DataTypeImpl::GetType<int16_t>()) {
|
||||
return ComputeRange<int16_t>(ctx);
|
||||
}
|
||||
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Range op: Unsupported tensor data type:", data_type);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
22
onnxruntime/core/providers/cuda/generator/range.h
Normal file
22
onnxruntime/core/providers/cuda/generator/range.h
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
class Range final : public CudaKernel {
|
||||
public:
|
||||
explicit Range(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
42
onnxruntime/core/providers/cuda/generator/range_impl.cu
Normal file
42
onnxruntime/core/providers/cuda/generator/range_impl.cu
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "range_impl.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
__global__ void RangeKernel(const T start, const T delta, const int count, T* output) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (index < count) {
|
||||
output[index] = start + delta * index;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool RangeImpl(const T start, const T delta, const int count, T* output) {
|
||||
constexpr int block_size = 256;
|
||||
int grid_size = (count + block_size - 1) / block_size;
|
||||
RangeKernel<T><<<grid_size, block_size, 0>>>(start, delta, count, output);
|
||||
return CUDA_CALL(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template bool RangeImpl<T>(const T start, const T delta, const int count, T* output);
|
||||
|
||||
SPECIALIZED_IMPL(int16_t)
|
||||
SPECIALIZED_IMPL(int32_t)
|
||||
SPECIALIZED_IMPL(int64_t)
|
||||
SPECIALIZED_IMPL(float)
|
||||
SPECIALIZED_IMPL(double)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
16
onnxruntime/core/providers/cuda/generator/range_impl.h
Normal file
16
onnxruntime/core/providers/cuda/generator/range_impl.h
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
template <typename T>
|
||||
bool RangeImpl(const T start, const T delta, const int count, T* output);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue