mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Rashuai/cuda top k (#1919)
* implement cuda topk * implement heap * add type support * refactor interface * add support for sorting by index * add test case * use cub device radix sort * register for opset 9 and 10 * add opset 9/10 delaration * refactor code * refactor code * fix comment * fix comment * switch to scratched mem
This commit is contained in:
parent
4bcd8bfca1
commit
d6849bd26c
5 changed files with 244 additions and 0 deletions
|
|
@ -564,6 +564,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, R
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, Scatter);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, ScatterElements);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK);
|
||||
|
||||
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
|
||||
|
|
@ -919,6 +923,11 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, Scatter)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, ScatterElements)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK)>,
|
||||
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
101
onnxruntime/core/providers/cuda/math/topk.cc
Normal file
101
onnxruntime/core/providers/cuda/math/topk.cc
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "topk.h"
|
||||
#include "topk_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
TopK,
|
||||
kOnnxDomain,
|
||||
1,9,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
TopK<false>);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
TopK,
|
||||
kOnnxDomain,
|
||||
10,10,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder().InputMemoryType<OrtMemTypeCPUInput>(1).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
TopK<true>);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
TopK,
|
||||
kOnnxDomain,
|
||||
11,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder().InputMemoryType<OrtMemTypeCPUInput>(1).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
TopK<true>);
|
||||
|
||||
template <bool inputk>
|
||||
TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
|
||||
info.GetAttrOrDefault<int64_t>("axis", &axis_, -1);
|
||||
info.GetAttrOrDefault<int64_t>("largest", &largest_, 1);
|
||||
info.GetAttrOrDefault<int64_t>("sorted", &sorted_, 1);
|
||||
if (!inputk) {
|
||||
info.GetAttrOrDefault<int64_t>("k", &K_, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#define ISTYPE(T) tensor_X->DataType() == DataTypeImpl::GetType<T>()
|
||||
#define TOPKIMPL(T) TopKImpl<T>(this, tensor_X->Data<T>(), \
|
||||
static_cast<T*>(tensor_V->MutableDataRaw()), \
|
||||
static_cast<int64_t*>(tensor_I->MutableDataRaw()), \
|
||||
elem_nums_cuda.GpuPtr(), \
|
||||
elem_nums.size(), \
|
||||
axis, K_, largest_, sorted_, N, dimension)
|
||||
|
||||
template <bool inputk>
|
||||
Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
auto tensor_X = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(nullptr != tensor_X);
|
||||
auto rank = static_cast<int64_t>(tensor_X->Shape().NumDimensions());
|
||||
auto axis = axis_ < 0 ? rank + axis_ : axis_;
|
||||
ORT_ENFORCE(axis > -1 && axis < rank);
|
||||
|
||||
if (inputk) {
|
||||
auto tensor_K = ctx->Input<Tensor>(1);
|
||||
ORT_ENFORCE(nullptr != tensor_K);
|
||||
K_ = *tensor_K->Data<int64_t>();
|
||||
ORT_ENFORCE(K_ >= 0 && K_ <= tensor_X->Shape().GetDims()[axis]);
|
||||
}
|
||||
|
||||
auto output_shape = tensor_X->Shape();
|
||||
output_shape[axis] = K_;
|
||||
auto tensor_V = ctx->Output(0, output_shape);
|
||||
auto tensor_I = ctx->Output(1, output_shape);
|
||||
|
||||
if (0 == K_) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto elem_nums = tensor_X->Shape().GetDims();
|
||||
auto dimension = elem_nums[axis];
|
||||
for (auto i = static_cast<int32_t>(elem_nums.size()) - 2; i >= 0; --i) {
|
||||
elem_nums[i] *= elem_nums[i + 1];
|
||||
}
|
||||
|
||||
auto N = elem_nums[0] / dimension;
|
||||
CudaAsyncBuffer<int64_t> elem_nums_cuda(this, elem_nums);
|
||||
ORT_RETURN_IF_ERROR(elem_nums_cuda.CopyToGpu());
|
||||
|
||||
if (ISTYPE(uint8_t)) return TOPKIMPL(uint8_t);
|
||||
if (ISTYPE(uint16_t)) return TOPKIMPL(uint16_t);
|
||||
if (ISTYPE(uint32_t)) return TOPKIMPL(uint32_t);
|
||||
if (ISTYPE(uint64_t)) return TOPKIMPL(uint64_t);
|
||||
if (ISTYPE(int8_t)) return TOPKIMPL(int8_t);
|
||||
if (ISTYPE(int16_t)) return TOPKIMPL(int16_t);
|
||||
if (ISTYPE(int32_t)) return TOPKIMPL(int32_t);
|
||||
if (ISTYPE(int64_t)) return TOPKIMPL(int64_t);
|
||||
if (ISTYPE(float)) return TOPKIMPL(float);
|
||||
if (ISTYPE(double)) return TOPKIMPL(double);
|
||||
if (ISTYPE(uint8_t)) return TOPKIMPL(uint8_t);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for TopK operator");
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
23
onnxruntime/core/providers/cuda/math/topk.h
Normal file
23
onnxruntime/core/providers/cuda/math/topk.h
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template<bool inputk>
|
||||
class TopK final : public CudaKernel {
|
||||
public:
|
||||
TopK(const OpKernelInfo&);
|
||||
Status ComputeInternal(OpKernelContext*) const override;
|
||||
|
||||
private:
|
||||
int64_t axis_;
|
||||
int64_t largest_;
|
||||
int64_t sorted_;
|
||||
mutable int64_t K_;
|
||||
};
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
94
onnxruntime/core/providers/cuda/math/topk_impl.cu
Normal file
94
onnxruntime/core/providers/cuda/math/topk_impl.cu
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "topk_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "cub/cub.cuh"
|
||||
#include <limits>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
__global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
|
||||
auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis];
|
||||
auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
|
||||
auto input_offset = left + id * (axis == size - 1 ? 1 : elem_nums[axis + 1]) + right;
|
||||
output_v[id] = input_x[input_offset];
|
||||
output_i[id] = id;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void FillOutput(const T* input_v, const int64_t* input_i, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
|
||||
auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis] * K / dimension;
|
||||
auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
|
||||
auto output_offset = left + id * (axis == size - 1 ? 1 : elem_nums[axis + 1]) + right;
|
||||
output_v[output_offset] = input_v[id];
|
||||
output_i[output_offset] = input_i[id];
|
||||
}
|
||||
|
||||
__global__ void ExcludeOutput(int64_t* output_i, int64_t K, int64_t dimension) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
|
||||
if (id >= K) {
|
||||
output_i[id] = dimension;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
|
||||
auto input_key_buffer = kernel->GetScratchBuffer<T>(dimension);
|
||||
auto output_key_buffer = kernel->GetScratchBuffer<T>(dimension);
|
||||
auto input_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
|
||||
auto output_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
|
||||
auto input_key = input_key_buffer.get();
|
||||
auto output_key = output_key_buffer.get();
|
||||
auto input_value = input_value_buffer.get();
|
||||
auto output_value = output_value_buffer.get();
|
||||
size_t temp_bytes = 0;
|
||||
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension));
|
||||
auto temp_storage_buffer = kernel->GetScratchBuffer<char>(temp_bytes);
|
||||
auto temp_storage = temp_storage_buffer.get();
|
||||
auto blocksPerGridD = (int)(ceil(static_cast<float>(dimension) / GridDim::maxThreadsPerBlock));
|
||||
auto blocksPerGridK = (int)(ceil(static_cast<float>(K) / GridDim::maxThreadsPerBlock));
|
||||
for (int64_t i = 0; i < N; i++) {
|
||||
FillInput<T><<<blocksPerGridD, GridDim::maxThreadsPerBlock, 0>>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension);
|
||||
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension));
|
||||
if (1 == sorted) {
|
||||
FillOutput<T><<<blocksPerGridK, GridDim::maxThreadsPerBlock, 0>>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
|
||||
} else { //reorder by ascending index
|
||||
ExcludeOutput<<<blocksPerGridD, GridDim::maxThreadsPerBlock, 0>>>(output_value, K, dimension);
|
||||
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension));
|
||||
FillOutput<T><<<blocksPerGridK, GridDim::maxThreadsPerBlock, 0>>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define TOPKIMPLE(T) template Status TopKImpl<T>(const CudaKernel* kernel, \
|
||||
const T* input_x, \
|
||||
T* output_v, \
|
||||
int64_t* output_i, \
|
||||
const int64_t* elem_nums, \
|
||||
size_t size, \
|
||||
int64_t axis, \
|
||||
int64_t K, \
|
||||
int64_t largest, \
|
||||
int64_t sorted, \
|
||||
int64_t N, \
|
||||
int64_t dimension)
|
||||
|
||||
TOPKIMPLE(uint8_t);
|
||||
TOPKIMPLE(uint16_t);
|
||||
TOPKIMPLE(uint32_t);
|
||||
TOPKIMPLE(uint64_t);
|
||||
TOPKIMPLE(int8_t);
|
||||
TOPKIMPLE(int16_t);
|
||||
TOPKIMPLE(int32_t);
|
||||
TOPKIMPLE(int64_t);
|
||||
TOPKIMPLE(float);
|
||||
TOPKIMPLE(double);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
17
onnxruntime/core/providers/cuda/math/topk_impl.h
Normal file
17
onnxruntime/core/providers/cuda/math/topk_impl.h
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/common/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue