Rashuai/boost cuda TopK performance (#2826)

* Implement Bitonic and Radix TopK

* remove needless print out

* fix com err

* add negative support

* fix comments

Co-authored-by: Randy <45701928+RandyShuai@users.noreply.github.com>
This commit is contained in:
RandySheriffH 2020-01-21 13:40:38 -08:00 committed by GitHub
parent 08113b80cc
commit 38b34babe0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 419 additions and 25 deletions

View file

@ -3,12 +3,344 @@
#include "topk_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "device_atomic_functions.h"
#include "cub/cub.cuh"
#include "cub/util_type.cuh"
#include "cub/util_allocator.cuh"
#include "cub/device/device_radix_sort.cuh"
#include <limits>
namespace onnxruntime {
namespace cuda {
using namespace cub;
// Key/value pair moved around as a unit by the shared-memory TopK kernel:
// `key` carries the element being ranked and `val` the index it was loaded from.
template <typename T>
struct KV {
  T key;        // element value used for comparisons
  int64_t val;  // index associated with `key`
};
// Helper macros for the TopK kernels below. Several of them expand to
// expressions that reference names from the scope in which they are used
// (left_dim, mid_dim, right_dim, K, dimension, largest, type_min, type_max),
// so they are only usable where those names are defined.

// Shorthand for GridDim::maxThreadsPerBlock.
#define BT GridDim::maxThreadsPerBlock
// Rounds N up to a power of two: 2^ceil(log2(N)), computed in double and cast to int64_t.
#define ALIGN(N) static_cast<int64_t>(pow(2, ceil(log2(static_cast<double>(N)))))
// Flat offset into the input for position `idx` along the TopK axis.
#define FROM(idx) (left_dim + (idx)*mid_dim + right_dim)
// Flat offset into the output for position `idx`; left_dim is rescaled by K / dimension.
#define TO(idx) (left_dim * K / dimension + (idx)*mid_dim + right_dim)
// Padding key: type_min when selecting the largest elements, type_max otherwise.
#define TRIVIAL (1 == largest ? type_min : type_max)
// Picks the KV with the larger key; on equal keys the choice depends on `val` and `largest`.
// NOTE: arguments are not parenthesized and are evaluated more than once - pass plain lvalues.
#define BIGGER(n, m) (n.key > m.key ? n : (n.key < m.key ? m : (n.val > m.val ? (1 == largest ? m : n) : (1 == largest ? n : m))))
// Picks the KV with the smaller key; on equal keys the choice depends on `val` and `largest`.
#define SMALLER(n, m) (n.key < m.key ? n : (n.key > m.key ? m : (n.val < m.val ? (1 == largest ? m : n) : (1 == largest ? n : m))))
// True when n orders before m: smaller key, or keys not greater and `val` compared
// in the direction selected by `largest`.
#define IS_SMALLER(n, m) (n.key < m.key || !(n.key > m.key) && (1 == largest ? n.val > m.val : n.val < m.val))
// Minimum of n and m; each argument is evaluated twice.
#define LESS(n, m) ((n) <= (m) ? (n) : (m))
// TopK for one slice per thread block, done entirely in dynamic shared memory.
//
// X: input; V / I: output values / indices; elem_nums, size, axis: used to derive
// the strides of the slice handled by this block (blockIdx.x selects the slice).
// K / aligned_K and dimension / aligned_dimension: requested sizes and the padded
// sizes the compare-exchange loops run over (the caller passes ALIGN(K) and
// ALIGN(dimension)). largest / sorted: 1 selects largest-first / sorted output.
// type_min / type_max: padding keys used via TRIVIAL.
//
// Launch requirements visible from the code: dynamic shared memory must hold
// aligned_dimension KV<T> entries, and since each thread handles one (i, j) pair
// with i derived from 2 * tid, the block needs at least aligned_dimension / 2 threads.
template <typename T>
__global__ void BitonicTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t aligned_K, int64_t largest, int64_t sorted, int64_t dimension, int64_t aligned_dimension, T type_min, T type_max) {
auto tid = threadIdx.x;
auto bid = blockIdx.x;
extern __shared__ char shared_mem[];
auto S = (KV<T>*)(shared_mem);
// Stride pieces consumed by the FROM / TO macros.
auto mid_dim = axis == size - 1 ? 1 : elem_nums[axis + 1];
auto left_dim = bid / mid_dim * elem_nums[axis];
auto right_dim = axis == size - 1 ? 0 : bid % elem_nums[axis + 1];
// Load the slice into shared memory; slots past `dimension` get the padding key.
for (auto i = tid; i < aligned_dimension; i += blockDim.x) {
S[i].key = i < dimension ? X[FROM(i)] : TRIVIAL;
S[i].val = i;
}
__syncthreads();
//sort each K
// Compare-exchange passes over runs of up to aligned_K entries; the direction of
// each exchange flips with the (dir & i) bit.
for (int64_t len = 1; len < aligned_K; len <<= 1) {
auto dir = len << 1;
for (auto inc = len; inc > 0; inc >>= 1) {
auto low = tid & (inc - 1);
auto i = (tid << 1) - low;
auto j = i + inc;
if (j < aligned_dimension) {
auto reverse = (dir & i) == 0;
auto swap = reverse ^ IS_SMALLER(S[i], S[j]);
if (swap) {
auto tmp = S[i];
S[i] = S[j];
S[j] = tmp;
}
}
__syncthreads();
}
__syncthreads();
}
//merge and rebuild K
// Each round keeps, per pair of runs `len` apart, the better element of S[i] / S[j]
// in S[i] (BIGGER or SMALLER depending on `largest`), then re-runs the
// compare-exchange passes on the surviving aligned_K-sized run.
for (int64_t len = aligned_K; len < aligned_dimension; len <<= 1) {
auto dir = len << 1;
auto i = (tid << 1) - (tid & (len - 1));
auto j = i + len;
if (i % dir < aligned_K && j < aligned_dimension) {
S[i] = 1 == largest ? BIGGER(S[i], S[j]) : SMALLER(S[i], S[j]);
}
__syncthreads();
for (auto inc = aligned_K >> 1; inc > 0; inc >>= 1) {
auto ii = (tid << 1) - (tid & (inc - 1));
auto jj = ii + inc;
if (ii % dir < aligned_K && jj < aligned_dimension) {
auto reverse = (dir & ii) == 0;
auto swap = reverse ^ IS_SMALLER(S[ii], S[jj]);
if (swap) {
auto tmp = S[ii];
S[ii] = S[jj];
S[jj] = tmp;
}
}
__syncthreads();
}
__syncthreads();
}
//save top K
if (1 == sorted) {
if (1 == largest) {
// Entries [aligned_K - K, aligned_K) are written out in reverse order.
auto start = aligned_K - K;
if (tid >= start && tid < aligned_K) {
auto to = TO(aligned_K - 1 - tid);
V[to] = S[tid].key;
I[to] = S[tid].val;
}
} else {
// Entries [0, K) are written out in place.
if (tid < K) {
auto to = TO(tid);
V[to] = S[tid].key;
I[to] = S[tid].val;
}
}
} else {
// Unsorted output: overwrite the index of every entry outside the top K with
// aligned_dimension (larger than any real index), then reorder by index.
if (1 == largest) {
auto start = aligned_K - K;
if (tid < start) {
S[tid].val = aligned_dimension;
}
} else {
if (tid >= K && tid < aligned_K) {
S[tid].val = aligned_dimension;
}
}
__syncthreads();
//sort by index ascending
for (int64_t len = 1; len < aligned_K; len <<= 1) {
auto dir = len << 1;
for (int64_t inc = len; inc > 0; inc >>= 1) {
auto low = tid & (inc - 1);
auto i = (tid << 1) - low;
auto j = i + inc;
if (j < aligned_K) {
auto reverse = (dir & i) == 0;
auto swap = reverse ^ (S[i].val < S[j].val);
if (swap) {
auto tmp = S[i];
S[i] = S[j];
S[j] = tmp;
}
}
__syncthreads();
}
__syncthreads();
}
// The first K entries are written out after the reorder.
if (tid < K) {
auto to = TO(tid);
V[to] = S[tid].key;
I[to] = S[tid].val;
}
}
}
// Approximate equality: true when |t0 - t1|, converted to double, is below 1.0e-5.
template <typename T>
__device__ __inline__ bool Equal(const T& t0, const T& t1) {
  // Subtract the smaller from the larger so the difference is non-negative.
  if (t0 > t1) {
    return (double)(t0 - t1) < 1.0e-5;
  }
  return (double)(t1 - t0) < 1.0e-5;
}
// Returns true when *t0 and *t1 agree on every bit above the lowest `skip` bits,
// i.e. their XOR shifted right by `skip` is zero.
template<typename T>
__device__ bool SamePrefix(const T* t0, const T* t1, int64_t skip) {
  const auto differing_bits = (*t0) ^ (*t1);
  return (differing_bits >> skip) == 0;
}

// float overload: compares the raw 32-bit patterns.
__device__ bool SamePrefix(const float* f0, const float* f1, int64_t skip) {
  const int32_t* bits0 = (const int32_t*)f0;
  const int32_t* bits1 = (const int32_t*)f1;
  return SamePrefix(bits0, bits1, skip);
}

// double overload: compares the raw 64-bit patterns.
__device__ bool SamePrefix(const double* d0, const double* d1, int64_t skip) {
  const int64_t* bits0 = (const int64_t*)d0;
  const int64_t* bits1 = (const int64_t*)d1;
  return SamePrefix(bits0, bits1, skip);
}
// Extracts the 8-bit digit of *t that starts at bit `skip`: (*t >> skip) & 255.
template<typename T>
__device__ int32_t Radix(const T* t, int64_t skip) {
  return ((*t)>>skip)&255;
}

// float overload: takes the digit from the raw 32-bit pattern.
__device__ int32_t Radix(const float* f, int64_t skip) {
  return Radix((const int32_t*)f, skip);
}

// double overload: takes the digit from the raw 64-bit pattern.
// The pointer must be reinterpreted as an integer type here; forwarding it as
// `const double*` would select this same overload again and recurse forever.
__device__ int32_t Radix(const double* d, int64_t skip) {
  return Radix((const int64_t*)d, skip);
}
// ORs `byte` (already shifted into position by the caller) into *t.
template<typename T>
__device__ void SetByte(T* t, int64_t byte) {
  *t = static_cast<T>((*t) | byte);
}

// float overload: ORs into the raw 32-bit pattern.
__device__ void SetByte(float* f, int64_t byte) {
  int32_t* raw_bits = (int32_t*)f;
  SetByte(raw_bits, byte);
}

// double overload: ORs into the raw 64-bit pattern.
__device__ void SetByte(double* d, int64_t byte) {
  int64_t* raw_bits = (int64_t*)d;
  SetByte(raw_bits, byte);
}
// TopK for one slice per thread block using a byte-wise radix selection of the
// K-th element followed by a compaction pass.
//
// Template parameters:
//   THREADS - block size the CUB collectives are instantiated for.
//   KPT     - keys per thread for the final BlockRadixSort. Only the first
//             THREADS * KPT outputs are loaded for sorting, so the sorted path
//             presumably expects K <= THREADS * KPT (the caller picks KPT from K).
// Parameters:
//   X                      - input tensor.
//   V, I                   - output values / indices.
//   elem_nums, size, axis  - used to derive the slice strides; blockIdx.x selects the slice.
//   K, largest, sorted     - TopK request (1 == largest selects the largest elements,
//                            1 == sorted asks for sorted output).
//   dimension              - number of elements along the TopK axis.
//   XPT                    - not referenced in the body.
//   type_min, type_max     - padding keys for the sort stage.
//
// Launch requirements visible from the code: dynamic shared memory must hold at
// least 256 uint32_t (the per-byte histogram H), and at least 256 threads are
// needed so that `tid < 256` clears every histogram bin.
template<typename T, int64_t THREADS, int64_t KPT>
__global__ void RadixTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t dimension, int64_t XPT, T type_min, T type_max) {
  auto tid = threadIdx.x;
  auto bid = blockIdx.x;
  extern __shared__ char shared_mem[];
  auto H = (uint32_t*)shared_mem;
  // Stride pieces consumed by the FROM / TO macros.
  auto mid_dim = axis == size - 1 ? 1 : elem_nums[axis + 1];
  auto left_dim = bid / mid_dim * elem_nums[axis];
  auto right_dim = axis == size - 1 ? 0 : bid % elem_nums[axis + 1];
  T Kth = (T)0, sign = (T)1;
  typedef BlockScan<uint32_t, THREADS> BlockScan;
  typedef BlockReduce<uint32_t, THREADS> BlockReduce;
  typedef BlockRadixSort<T, THREADS, KPT, int64_t> BlockRadixSort;
  // The three collectives share one scratch area; every reuse below is separated
  // by a __syncthreads().
  __shared__ union {
    typename BlockScan::TempStorage scan;
    typename BlockReduce::TempStorage reduce;
    typename BlockRadixSort::TempStorage sort;
  } temp_storage;

  // Stage 1: count strictly positive and strictly negative elements of the slice.
  uint32_t positive = 0, negative = 0;
  for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
    T x = X[FROM(x_i)];
    if (x > 0) {
      ++positive;
    } else if (x < 0) {
      ++negative;
    }
  }
  __syncthreads();
  positive = BlockReduce(temp_storage.reduce).Sum(positive);
  // Barrier required before the reduce temp storage is reused by the next collective.
  __syncthreads();
  negative = BlockReduce(temp_storage.reduce).Sum(negative);
  // Only thread 0 holds the valid reduction result; broadcast through H.
  if (0 == tid) {
    H[0] = positive;
    H[1] = negative;
  }
  __syncthreads();
  positive = H[0];
  negative = H[1];

  // Stage 2: find the K-th element. The radix search only runs when the counts
  // show the K-th element is strictly positive or strictly negative; otherwise
  // Kth keeps its initial value 0.
  if ((1 == largest && (K <= positive || dimension - K + 1 <= negative)) ||
      (0 == largest && (K <= negative || dimension - K + 1 <= positive))) {
    // KK is the rank searched for among the positive values of sign * X;
    // sign = -1 flips the input so the negative side is searched instead.
    auto KK = K;
    if (1 == largest) {
      if (KK > positive) {
        KK = dimension - KK + 1;
        sign = (T)-1;
      }
    } else {
      if (KK > negative) {
        KK = dimension - KK + 1;
      } else {
        sign = (T)-1;
      }
    }
    __syncthreads();
    // Resolve Kth one byte at a time, most significant byte first.
#pragma unroll
    for (int64_t byte = sizeof(T)-1; byte > -1; --byte) {
      if (tid < 256) H[tid] = 0;
      __syncthreads();
      auto skip = 8 * byte, prev_skip = 8 * (byte + 1);
      // Histogram the current byte of every positive candidate that matches the
      // bytes of Kth resolved so far.
      for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
        T x = sign*X[FROM(x_i)];
        if (x > 0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) {
          atomicAdd(&H[Radix(&x, skip)], 1);
        }
      }
      __syncthreads();
      // Every thread walks the histogram from the top bin down and updates its own
      // copy of KK / Kth; bin 0 needs no SetByte because ORing zero changes nothing.
      for (int64_t radix = 255; radix > 0; --radix) {
        if (H[radix] < KK) {
          KK -= H[radix];
        } else {
          SetByte(&Kth, radix<<skip);
          break;
        }
      }
      __syncthreads();
    }
    Kth *= sign;
  }

  // Stage 3: count, per thread, elements strictly better than Kth ("superior")
  // and elements Equal() to Kth, then turn the counts into output offsets.
  uint32_t superior = 0, equal = 0;
  for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
    auto x = X[FROM(x_i)];
    if (1 == largest && x > Kth || 0 == largest && x < Kth) {
      ++superior;
    } else if (Equal(x, Kth)) {
      ++equal;
    }
  }
  __syncthreads();
  auto all_superior = superior;
  all_superior = BlockReduce(temp_storage.reduce).Sum(all_superior);
  if (0 == tid) {
    H[0] = all_superior;
  }
  __syncthreads();
  all_superior = H[0];
  BlockScan(temp_storage.scan).ExclusiveSum(superior, superior);
  // Barrier required before the scan temp storage is reused by the next collective.
  __syncthreads();
  BlockScan(temp_storage.scan).ExclusiveSum(equal, equal);
  __syncthreads();
  // After the scans, `superior` / `equal` hold the counts of all lower-numbered threads.
  auto equal_quota = K - all_superior - equal;
  auto output_i = superior + LESS(K - all_superior, equal);
  for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
    auto x = X[FROM(x_i)];
    if (1 == largest && x > Kth || 0 == largest && x < Kth) {
      auto to_i = TO(output_i);
      V[to_i] = x;
      I[to_i] = x_i;
      ++output_i;
    } else if (Equal(x, Kth) && equal_quota > 0) {
      auto to_i = TO(output_i);
      V[to_i] = x;
      I[to_i] = x_i;
      ++output_i;
      --equal_quota;
    }
  }
  __syncthreads();

  // Stage 4 (sorted output only): reload the K results, sort them block-wide and
  // write them back.
  if (1 == sorted) {
    T keys[KPT];
    int64_t vals[KPT];
    for (int64_t k_i = tid, k_c = 0; k_c < KPT; k_i += blockDim.x, ++k_c) {
      if (k_i < K) {
        auto to_i = TO(k_i);
        keys[k_c] = V[to_i];
        vals[k_c] = I[to_i];
      } else {
        // Slots past K get the padding key for the chosen direction.
        if (1 == largest) {
          keys[k_c] = type_min;
        } else {
          keys[k_c] = type_max;
        }
      }
    }
    __syncthreads();
    if (1 == largest) {
      BlockRadixSort(temp_storage.sort).SortDescending(keys, vals);
    } else {
      BlockRadixSort(temp_storage.sort).Sort(keys, vals);
    }
    __syncthreads();
    // Each thread writes its KPT consecutive sorted items.
#pragma unroll
    for (int64_t k_c = 0; k_c < KPT; ++k_c) {
      auto k_i = tid * KPT + k_c;
      if (k_i < K) {
        auto to_i = TO(k_i);
        V[to_i] = keys[k_c];
        I[to_i] = vals[k_c];
      }
    }
  }
}
template <typename T>
__global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
@ -38,30 +370,47 @@ __global__ void ExcludeOutput(int64_t* output_i, int64_t K, int64_t dimension) {
// Host entry point that computes TopK along `axis` for N slices of `dimension` elements.
// NOTE(review): this span comes from a rendered diff hunk with its +/- markers
// stripped, so the previous body and the new body appear interleaved and the braces
// do not balance as displayed - confirm against the actual file before editing.
template <typename T>
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
// NOTE(review): the statements from here down to the ALIGN(K) line look like the
// previous implementation (full device radix sort per slice) - presumably removed lines.
auto input_key_buffer = kernel->GetScratchBuffer<T>(dimension);
auto output_key_buffer = kernel->GetScratchBuffer<T>(dimension);
auto input_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
auto output_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
auto input_key = input_key_buffer.get();
auto output_key = output_key_buffer.get();
auto input_value = input_value_buffer.get();
auto output_value = output_value_buffer.get();
size_t temp_bytes = 0;
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension));
auto temp_storage_buffer = kernel->GetScratchBuffer<char>(temp_bytes);
auto temp_storage = temp_storage_buffer.get();
auto blocksPerGridD = (int)(ceil(static_cast<float>(dimension) / GridDim::maxThreadsPerBlock));
auto blocksPerGridK = (int)(ceil(static_cast<float>(K) / GridDim::maxThreadsPerBlock));
for (int64_t i = 0; i < N; i++) {
FillInput<T><<<blocksPerGridD, GridDim::maxThreadsPerBlock, 0>>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension);
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension));
if (1 == sorted) {
FillOutput<T><<<blocksPerGridK, GridDim::maxThreadsPerBlock, 0>>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
} else { //reorder by ascending index
ExcludeOutput<<<blocksPerGridD, GridDim::maxThreadsPerBlock, 0>>>(output_value, K, dimension);
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension));
FillOutput<T><<<blocksPerGridK, GridDim::maxThreadsPerBlock, 0>>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
// Round K and dimension up to powers of two for the shared-memory kernel.
auto aligned_K = ALIGN(K);
auto aligned_dimension = ALIGN(dimension);
// Path 1: the padded slice is at most twice the block size, so one block per slice
// runs BitonicTopK with aligned_dimension KV<T> entries of dynamic shared memory.
if (aligned_dimension <= GridDim::maxThreadsPerBlock << 1) {
BitonicTopK<T><<<N, GridDim::maxThreadsPerBlock, aligned_dimension * sizeof(KV<T>)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, aligned_dimension, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
// Path 2: RadixTopK, one block per slice with a 256-entry uint32_t histogram in
// dynamic shared memory; the keys-per-thread template argument (2/4/8/16) is the
// smallest one for which BT * KPT >= K, or 2 when the output is unsorted.
} else if (K <= BT*16 || 0 == sorted) {
auto XPT = static_cast<int64_t>(ceil(static_cast<double>(dimension) / GridDim::maxThreadsPerBlock));
if (BT*2 >= K || 0 == sorted) {
RadixTopK<T,BT,2><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
} else if (BT*4>=K) {
RadixTopK<T,BT,4><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
} else if (BT*8>=K) {
RadixTopK<T,BT,8><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
} else {
RadixTopK<T,BT,16><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
}
// Path 3: fallback - sort each slice with cub::DeviceRadixSort and copy out the
// first K results.
} else {
auto input_key_buffer = kernel->GetScratchBuffer<T>(dimension);
auto output_key_buffer = kernel->GetScratchBuffer<T>(dimension);
auto input_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
auto output_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
auto* input_key = input_key_buffer.get();
auto* output_key = output_key_buffer.get();
auto* input_value = input_value_buffer.get();
auto* output_value = output_value_buffer.get();
// First SortPairs call with a null temp pointer only queries the scratch size.
size_t temp_bytes = 0;
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension));
auto temp_storage_buffer = kernel->GetScratchBuffer<char>(temp_bytes);
auto* temp_storage = temp_storage_buffer.get();
auto blocks_per_grid_D = (int)(ceil(static_cast<float>(dimension) / BT));
auto blocks_per_grid_K = (int)(ceil(static_cast<float>(K) / BT));
for (int64_t i = 0; i < N; i++) {
FillInput<T><<<blocks_per_grid_D, BT, 0>>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension);
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension));
if (1 == sorted) {
FillOutput<T><<<blocks_per_grid_K, BT, 0>>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
} else { //reorder by ascending index
ExcludeOutput<<<blocks_per_grid_D, BT, 0>>>(output_value, K, dimension);
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension));
FillOutput<T><<<blocks_per_grid_K, BT, 0>>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
}
}
}
return Status::OK();
}
@ -91,4 +440,4 @@ TOPKIMPLE(float);
TOPKIMPLE(double);
} // namespace cuda
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -14,4 +14,4 @@ template <typename T>
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension);
} // namespace cuda
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -475,5 +475,50 @@ TEST(TopKOperator, SortedSelection) {
RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis, 0); // smallest values
}
// Sorted top-100 of a 1000-element ascending ramp.
TEST(TopKOperator, MediumArrayTopKSorted)
{
  // Input: 0, 1, ..., 999 in a single dimension.
  std::vector<int64_t> input_dimensions = {1000};
  std::vector<float> input_vals(1000, 0.0f);
  std::iota(input_vals.begin(), input_vals.end(), 0.0f);

  // Expected: 999 down to 900. Filling through reverse iterators produces the
  // descending order directly; indices equal values because the input is a ramp.
  std::vector<int64_t> expected_dimensions = {100};
  std::vector<float> expected_vals(100, 0.0f);
  std::iota(expected_vals.rbegin(), expected_vals.rend(), 900.0f);
  std::vector<int64_t> expected_indices(100, 0);
  std::iota(expected_indices.rbegin(), expected_indices.rend(), 900);

  RunTest(11, 100, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1);
}
// Sorted top-1000 of a 10000-element ascending ramp.
TEST(TopKOperator, BigArrayTopKSorted)
{
  // Input: 0, 1, ..., 9999 in a single dimension.
  std::vector<int64_t> input_dimensions = {10000};
  std::vector<float> input_vals(10000, 0.0f);
  std::iota(input_vals.begin(), input_vals.end(), 0.0f);

  // Expected: 9999 down to 9000. Filling through reverse iterators produces the
  // descending order directly; indices equal values because the input is a ramp.
  std::vector<int64_t> expected_dimensions = {1000};
  std::vector<float> expected_vals(1000, 0.0f);
  std::iota(expected_vals.rbegin(), expected_vals.rend(), 9000.0f);
  std::vector<int64_t> expected_indices(1000, 0);
  std::iota(expected_indices.rbegin(), expected_indices.rend(), 9000);

  RunTest(11, 1000, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1);
}
// Sorted top-9000 of a 10000-element ascending ramp.
TEST(TopKOperator, BigArrayBigTopKSorted)
{
  // Input: 0, 1, ..., 9999 in a single dimension.
  std::vector<int64_t> input_dimensions = {10000};
  std::vector<float> input_vals(10000, 0.0f);
  std::iota(input_vals.begin(), input_vals.end(), 0.0f);

  // Expected: 9999 down to 1000. Filling through reverse iterators produces the
  // descending order directly; indices equal values because the input is a ramp.
  std::vector<int64_t> expected_dimensions = {9000};
  std::vector<float> expected_vals(9000, 0.0f);
  std::iota(expected_vals.rbegin(), expected_vals.rend(), 1000.0f);
  std::vector<int64_t> expected_indices(9000, 0);
  std::iota(expected_indices.rbegin(), expected_indices.rend(), 1000);

  RunTest(11, 9000, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1);
}
} // namespace test
} // namespace onnxruntime