Rashuai/cuda top k (#1919)

* implement cuda topk

* implement heap

* add type support

* refactor interface

* add support for sorting by index

* add test case

* use cub device radix sort

* register for opset 9 and 10

* add opset 9/10 delaration

* refactor code

* refactor code

* fix comment

* fix comment

* switch to scratched mem
This commit is contained in:
RandySheriffH 2019-10-31 10:26:00 -07:00 committed by GitHub
parent 4bcd8bfca1
commit d6849bd26c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 244 additions and 0 deletions

View file

@ -564,6 +564,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, R
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, Scatter);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK);
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
@ -919,6 +923,11 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, Scatter)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "topk.h"
#include "topk_impl.h"
namespace onnxruntime {
namespace cuda {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
TopK,
kOnnxDomain,
1,9,
kCudaExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
TopK<false>);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
TopK,
kOnnxDomain,
10,10,
kCudaExecutionProvider,
KernelDefBuilder().InputMemoryType<OrtMemTypeCPUInput>(1).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
TopK<true>);
ONNX_OPERATOR_KERNEL_EX(
TopK,
kOnnxDomain,
11,
kCudaExecutionProvider,
KernelDefBuilder().InputMemoryType<OrtMemTypeCPUInput>(1).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
TopK<true>);
template <bool inputk>
TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
info.GetAttrOrDefault<int64_t>("axis", &axis_, -1);
info.GetAttrOrDefault<int64_t>("largest", &largest_, 1);
info.GetAttrOrDefault<int64_t>("sorted", &sorted_, 1);
if (!inputk) {
info.GetAttrOrDefault<int64_t>("k", &K_, 0);
}
}
#define ISTYPE(T) tensor_X->DataType() == DataTypeImpl::GetType<T>()
#define TOPKIMPL(T) TopKImpl<T>(this, tensor_X->Data<T>(), \
static_cast<T*>(tensor_V->MutableDataRaw()), \
static_cast<int64_t*>(tensor_I->MutableDataRaw()), \
elem_nums_cuda.GpuPtr(), \
elem_nums.size(), \
axis, K_, largest_, sorted_, N, dimension)
template <bool inputk>
Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
auto tensor_X = ctx->Input<Tensor>(0);
ORT_ENFORCE(nullptr != tensor_X);
auto rank = static_cast<int64_t>(tensor_X->Shape().NumDimensions());
auto axis = axis_ < 0 ? rank + axis_ : axis_;
ORT_ENFORCE(axis > -1 && axis < rank);
if (inputk) {
auto tensor_K = ctx->Input<Tensor>(1);
ORT_ENFORCE(nullptr != tensor_K);
K_ = *tensor_K->Data<int64_t>();
ORT_ENFORCE(K_ >= 0 && K_ <= tensor_X->Shape().GetDims()[axis]);
}
auto output_shape = tensor_X->Shape();
output_shape[axis] = K_;
auto tensor_V = ctx->Output(0, output_shape);
auto tensor_I = ctx->Output(1, output_shape);
if (0 == K_) {
return Status::OK();
}
auto elem_nums = tensor_X->Shape().GetDims();
auto dimension = elem_nums[axis];
for (auto i = static_cast<int32_t>(elem_nums.size()) - 2; i >= 0; --i) {
elem_nums[i] *= elem_nums[i + 1];
}
auto N = elem_nums[0] / dimension;
CudaAsyncBuffer<int64_t> elem_nums_cuda(this, elem_nums);
ORT_RETURN_IF_ERROR(elem_nums_cuda.CopyToGpu());
if (ISTYPE(uint8_t)) return TOPKIMPL(uint8_t);
if (ISTYPE(uint16_t)) return TOPKIMPL(uint16_t);
if (ISTYPE(uint32_t)) return TOPKIMPL(uint32_t);
if (ISTYPE(uint64_t)) return TOPKIMPL(uint64_t);
if (ISTYPE(int8_t)) return TOPKIMPL(int8_t);
if (ISTYPE(int16_t)) return TOPKIMPL(int16_t);
if (ISTYPE(int32_t)) return TOPKIMPL(int32_t);
if (ISTYPE(int64_t)) return TOPKIMPL(int64_t);
if (ISTYPE(float)) return TOPKIMPL(float);
if (ISTYPE(double)) return TOPKIMPL(double);
if (ISTYPE(uint8_t)) return TOPKIMPL(uint8_t);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for TopK operator");
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/cuda_common.h"
namespace onnxruntime {
namespace cuda {
template<bool inputk>
class TopK final : public CudaKernel {
public:
TopK(const OpKernelInfo&);
Status ComputeInternal(OpKernelContext*) const override;
private:
int64_t axis_;
int64_t largest_;
int64_t sorted_;
mutable int64_t K_;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "topk_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "cub/cub.cuh"
#include <limits>
namespace onnxruntime {
namespace cuda {
template <typename T>
__global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis];
auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
auto input_offset = left + id * (axis == size - 1 ? 1 : elem_nums[axis + 1]) + right;
output_v[id] = input_x[input_offset];
output_i[id] = id;
}
template <typename T>
__global__ void FillOutput(const T* input_v, const int64_t* input_i, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis] * K / dimension;
auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
auto output_offset = left + id * (axis == size - 1 ? 1 : elem_nums[axis + 1]) + right;
output_v[output_offset] = input_v[id];
output_i[output_offset] = input_i[id];
}
__global__ void ExcludeOutput(int64_t* output_i, int64_t K, int64_t dimension) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
if (id >= K) {
output_i[id] = dimension;
}
}
template <typename T>
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
auto input_key_buffer = kernel->GetScratchBuffer<T>(dimension);
auto output_key_buffer = kernel->GetScratchBuffer<T>(dimension);
auto input_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
auto output_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
auto input_key = input_key_buffer.get();
auto output_key = output_key_buffer.get();
auto input_value = input_value_buffer.get();
auto output_value = output_value_buffer.get();
size_t temp_bytes = 0;
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension));
auto temp_storage_buffer = kernel->GetScratchBuffer<char>(temp_bytes);
auto temp_storage = temp_storage_buffer.get();
auto blocksPerGridD = (int)(ceil(static_cast<float>(dimension) / GridDim::maxThreadsPerBlock));
auto blocksPerGridK = (int)(ceil(static_cast<float>(K) / GridDim::maxThreadsPerBlock));
for (int64_t i = 0; i < N; i++) {
FillInput<T><<<blocksPerGridD, GridDim::maxThreadsPerBlock, 0>>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension);
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension));
if (1 == sorted) {
FillOutput<T><<<blocksPerGridK, GridDim::maxThreadsPerBlock, 0>>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
} else { //reorder by ascending index
ExcludeOutput<<<blocksPerGridD, GridDim::maxThreadsPerBlock, 0>>>(output_value, K, dimension);
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension));
FillOutput<T><<<blocksPerGridK, GridDim::maxThreadsPerBlock, 0>>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
}
}
return Status::OK();
}
#define TOPKIMPLE(T) template Status TopKImpl<T>(const CudaKernel* kernel, \
const T* input_x, \
T* output_v, \
int64_t* output_i, \
const int64_t* elem_nums, \
size_t size, \
int64_t axis, \
int64_t K, \
int64_t largest, \
int64_t sorted, \
int64_t N, \
int64_t dimension)
TOPKIMPLE(uint8_t);
TOPKIMPLE(uint16_t);
TOPKIMPLE(uint32_t);
TOPKIMPLE(uint64_t);
TOPKIMPLE(int8_t);
TOPKIMPLE(int16_t);
TOPKIMPLE(int32_t);
TOPKIMPLE(int64_t);
TOPKIMPLE(float);
TOPKIMPLE(double);
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/common/common.h"
namespace onnxruntime {
namespace cuda {
template <typename T>
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension);
} // namespace cuda
} // namespace onnxruntime