diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cu b/onnxruntime/core/providers/cuda/math/topk_impl.cu index ba2d5c6e3e..3dc4852695 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cu +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cu @@ -3,12 +3,344 @@ #include "topk_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" +#include "device_atomic_functions.h" #include "cub/cub.cuh" +#include "cub/util_type.cuh" +#include "cub/util_allocator.cuh" +#include "cub/device/device_radix_sort.cuh" #include namespace onnxruntime { namespace cuda { +using namespace cub; + +template +struct KV { + T key; + int64_t val; +}; + +#define BT GridDim::maxThreadsPerBlock +#define ALIGN(N) static_cast(pow(2, ceil(log2(static_cast(N))))) +#define FROM(idx) (left_dim + (idx)*mid_dim + right_dim) +#define TO(idx) (left_dim * K / dimension + (idx)*mid_dim + right_dim) +#define TRIVIAL (1 == largest ? type_min : type_max) +#define BIGGER(n, m) (n.key > m.key ? n : (n.key < m.key ? m : (n.val > m.val ? (1 == largest ? m : n) : (1 == largest ? n : m)))) +#define SMALLER(n, m) (n.key < m.key ? n : (n.key > m.key ? m : (n.val < m.val ? (1 == largest ? m : n) : (1 == largest ? n : m)))) +#define IS_SMALLER(n, m) (n.key < m.key || !(n.key > m.key) && (1 == largest ? n.val > m.val : n.val < m.val)) +#define LESS(n, m) ((n) <= (m) ? (n) : (m)) + +template +__global__ void BitonicTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t aligned_K, int64_t largest, int64_t sorted, int64_t dimension, int64_t aligned_dimension, T type_min, T type_max) { + auto tid = threadIdx.x; + auto bid = blockIdx.x; + extern __shared__ char shared_mem[]; + auto S = (KV*)(shared_mem); + auto mid_dim = axis == size - 1 ? 1 : elem_nums[axis + 1]; + auto left_dim = bid / mid_dim * elem_nums[axis]; + auto right_dim = axis == size - 1 ? 
0 : bid % elem_nums[axis + 1]; + for (auto i = tid; i < aligned_dimension; i += blockDim.x) { + S[i].key = i < dimension ? X[FROM(i)] : TRIVIAL; + S[i].val = i; + } + __syncthreads(); + //sort each K + for (int64_t len = 1; len < aligned_K; len <<= 1) { + auto dir = len << 1; + for (auto inc = len; inc > 0; inc >>= 1) { + auto low = tid & (inc - 1); + auto i = (tid << 1) - low; + auto j = i + inc; + if (j < aligned_dimension) { + auto reverse = (dir & i) == 0; + auto swap = reverse ^ IS_SMALLER(S[i], S[j]); + if (swap) { + auto tmp = S[i]; + S[i] = S[j]; + S[j] = tmp; + } + } + __syncthreads(); + } + __syncthreads(); + } + //merge and rebuild K + for (int64_t len = aligned_K; len < aligned_dimension; len <<= 1) { + auto dir = len << 1; + auto i = (tid << 1) - (tid & (len - 1)); + auto j = i + len; + if (i % dir < aligned_K && j < aligned_dimension) { + S[i] = 1 == largest ? BIGGER(S[i], S[j]) : SMALLER(S[i], S[j]); + } + __syncthreads(); + for (auto inc = aligned_K >> 1; inc > 0; inc >>= 1) { + auto ii = (tid << 1) - (tid & (inc - 1)); + auto jj = ii + inc; + if (ii % dir < aligned_K && jj < aligned_dimension) { + auto reverse = (dir & ii) == 0; + auto swap = reverse ^ IS_SMALLER(S[ii], S[jj]); + if (swap) { + auto tmp = S[ii]; + S[ii] = S[jj]; + S[jj] = tmp; + } + } + __syncthreads(); + } + __syncthreads(); + } + //save top K + if (1 == sorted) { + if (1 == largest) { + auto start = aligned_K - K; + if (tid >= start && tid < aligned_K) { + auto to = TO(aligned_K - 1 - tid); + V[to] = S[tid].key; + I[to] = S[tid].val; + } + } else { + if (tid < K) { + auto to = TO(tid); + V[to] = S[tid].key; + I[to] = S[tid].val; + } + } + } else { + if (1 == largest) { + auto start = aligned_K - K; + if (tid < start) { + S[tid].val = aligned_dimension; + } + } else { + if (tid >= K && tid < aligned_K) { + S[tid].val = aligned_dimension; + } + } + __syncthreads(); + //sort by index ascending + for (int64_t len = 1; len < aligned_K; len <<= 1) { + auto dir = len << 1; + for (int64_t 
inc = len; inc > 0; inc >>= 1) { + auto low = tid & (inc - 1); + auto i = (tid << 1) - low; + auto j = i + inc; + if (j < aligned_K) { + auto reverse = (dir & i) == 0; + auto swap = reverse ^ (S[i].val < S[j].val); + if (swap) { + auto tmp = S[i]; + S[i] = S[j]; + S[j] = tmp; + } + } + __syncthreads(); + } + __syncthreads(); + } + if (tid < K) { + auto to = TO(tid); + V[to] = S[tid].key; + I[to] = S[tid].val; + } + } +} + +template +__device__ __inline__ bool Equal(const T& t0, const T& t1) { + auto t2 = t0 > t1 ? t0 - t1 : t1 - t0; + return (double)t2 < 1.0e-5; +} + +template +__device__ bool SamePrefix(const T* t0, const T* t1, int64_t skip) { + return (((*t0)^(*t1))>>skip) == 0; +} + +__device__ bool SamePrefix(const float* f0, const float* f1, int64_t skip) { + return SamePrefix((const int32_t*)f0, (const int32_t*)f1, skip); +} + +__device__ bool SamePrefix(const double* d0, const double* d1, int64_t skip) { + return SamePrefix((const int64_t*)d0, (const int64_t*)d1, skip); +} + +template +__device__ int32_t Radix(const T* t, int64_t skip) { + return ((*t)>>skip)&255; +} + +__device__ int32_t Radix(const float* f, int64_t skip) { + return Radix((const int32_t*)f, skip); +} + +__device__ int32_t Radix(const double* d, int64_t skip) { /* Reads the 8-bit digit at bit offset skip from the double's bit pattern: the pointer is reinterpreted as int64_t so the generic integer overload does the shift-and-mask. Casting to const double* here would select this same overload again and recurse forever. */ + return Radix((const int64_t*)d, skip); +} + +template +__device__ void SetByte(T* t, int64_t byte) { + (*t) |= byte; +} + +__device__ void SetByte(float* f, int64_t byte) { + SetByte((int32_t*)f, byte); +} + +__device__ void SetByte(double* d, int64_t byte) { + SetByte((int64_t*)d, byte); +} + +template +__global__ void RadixTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t dimension, int64_t XPT, T type_min, T type_max) { + auto tid = threadIdx.x; + auto bid = blockIdx.x; + extern __shared__ char shared_mem[]; + auto H = (uint32_t*)shared_mem; + auto mid_dim = axis == size - 1 ? 
1 : elem_nums[axis + 1]; + auto left_dim = bid / mid_dim * elem_nums[axis]; + auto right_dim = axis == size - 1 ? 0 : bid % elem_nums[axis + 1]; + T Kth = (T)0, sign = (T)1; + typedef BlockScan BlockScan; + typedef BlockReduce BlockReduce; + typedef BlockRadixSort BlockRadixSort; + __shared__ union { + typename BlockScan::TempStorage scan; + typename BlockReduce::TempStorage reduce; + typename BlockRadixSort::TempStorage sort; + } temp_storage; + uint32_t positive = 0, negative = 0; + for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) { + T x = X[FROM(x_i)]; + if (x > 0) { + ++positive; + } else if (x < 0) { + ++negative; + } + } + __syncthreads(); + positive = BlockReduce(temp_storage.reduce).Sum(positive); __syncthreads(); /* barrier: temp_storage.reduce is reused by the next Sum, so every thread must be done with the first reduction before it is overwritten */ + negative = BlockReduce(temp_storage.reduce).Sum(negative); + if (0 == tid) { + H[0] = positive; + H[1] = negative; + } + __syncthreads(); + positive = H[0]; + negative = H[1]; + if ((1 == largest && (K <= positive || dimension - K + 1 <= negative)) || + (0 == largest && (K <= negative || dimension - K + 1 <= positive))) { + auto KK = K; + if (1 == largest) { + if (KK > positive) { + KK = dimension - KK + 1; + sign = (T)-1; + } + } else { + if (KK > negative) { + KK = dimension - KK + 1; + } else { + sign = (T)-1; + } + } + __syncthreads(); + #pragma unroll + for (int64_t byte = sizeof(T)-1; byte > -1; --byte) { + if (tid < 256) H[tid] = 0; + __syncthreads(); + auto skip = 8 * byte, prev_skip = 8 * (byte + 1); + for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) { + T x = sign*X[FROM(x_i)]; + if (x > 0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) { + atomicAdd(&H[Radix(&x, skip)], 1); + } + } + __syncthreads(); + for (int64_t radix = 255; radix > 0; --radix) { + if (H[radix] < KK) { + KK -= H[radix]; + } else { + SetByte(&Kth, radix< Kth || 0 == largest && x < Kth) { + ++superior; + } else if (Equal(x, Kth)) { + ++equal; + } + } + __syncthreads(); + auto all_superior = superior; + all_superior = 
BlockReduce(temp_storage.reduce).Sum(all_superior); + if (0 == tid) { + H[0] = all_superior; + } + __syncthreads(); + all_superior = H[0]; + BlockScan(temp_storage.scan).ExclusiveSum(superior, superior); __syncthreads(); /* barrier: temp_storage.scan is reused by the next ExclusiveSum, so every thread must be done with the first scan before it is overwritten */ + BlockScan(temp_storage.scan).ExclusiveSum(equal, equal); + __syncthreads(); + auto equal_quota = K - all_superior - equal; + auto output_i = superior + LESS(K - all_superior, equal); + for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) { + auto x = X[FROM(x_i)]; + if (1 == largest && x > Kth || 0 == largest && x < Kth) { + auto to_i = TO(output_i); + V[to_i] = x; + I[to_i] = x_i; + ++output_i; + } else if (Equal(x, Kth) && equal_quota > 0) { + auto to_i = TO(output_i); + V[to_i] = x; + I[to_i] = x_i; + ++output_i; + --equal_quota; + } + } + __syncthreads(); + if (1 == sorted) { + T keys[KPT]; + int64_t vals[KPT]; + for (int64_t k_i = tid, k_c = 0; k_c < KPT; k_i += blockDim.x, ++k_c) { + if (k_i < K) { + auto to_i = TO(k_i); + keys[k_c] = V[to_i]; + vals[k_c] = I[to_i]; + } else { + if (1 == largest) { + keys[k_c] = type_min; + } else { + keys[k_c] = type_max; + } + } + } + __syncthreads(); + if (1 == largest) { + BlockRadixSort(temp_storage.sort).SortDescending(keys, vals); + } else { + BlockRadixSort(temp_storage.sort).Sort(keys, vals); + } + __syncthreads(); + #pragma unroll + for (int64_t k_c = 0; k_c < KPT; ++k_c) { + auto k_i = tid * KPT + k_c; + if (k_i < K) { + auto to_i = TO(k_i); + V[to_i] = keys[k_c]; + I[to_i] = vals[k_c]; + } + } + } +} + template __global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension); @@ -38,30 +370,47 @@ __global__ void ExcludeOutput(int64_t* output_i, int64_t K, int64_t dimension) { template Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t 
largest, int64_t sorted, int64_t N, int64_t dimension) { - auto input_key_buffer = kernel->GetScratchBuffer(dimension); - auto output_key_buffer = kernel->GetScratchBuffer(dimension); - auto input_value_buffer = kernel->GetScratchBuffer(dimension); - auto output_value_buffer = kernel->GetScratchBuffer(dimension); - auto input_key = input_key_buffer.get(); - auto output_key = output_key_buffer.get(); - auto input_value = input_value_buffer.get(); - auto output_value = output_value_buffer.get(); - size_t temp_bytes = 0; - CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension)); - auto temp_storage_buffer = kernel->GetScratchBuffer(temp_bytes); - auto temp_storage = temp_storage_buffer.get(); - auto blocksPerGridD = (int)(ceil(static_cast(dimension) / GridDim::maxThreadsPerBlock)); - auto blocksPerGridK = (int)(ceil(static_cast(K) / GridDim::maxThreadsPerBlock)); - for (int64_t i = 0; i < N; i++) { - FillInput<<>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension); - CUDA_RETURN_IF_ERROR(1 == largest ? 
cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension)); - if (1 == sorted) { - FillOutput<<>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension); - } else { //reorder by ascending index - ExcludeOutput<<>>(output_value, K, dimension); - CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension)); - FillOutput<<>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension); + auto aligned_K = ALIGN(K); + auto aligned_dimension = ALIGN(dimension); + if (aligned_dimension <= GridDim::maxThreadsPerBlock << 1) { + BitonicTopK<<)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, aligned_dimension, std::numeric_limits::lowest(), std::numeric_limits::max()); + } else if (K <= BT*16 || 0 == sorted) { + auto XPT = static_cast(ceil(static_cast(dimension) / GridDim::maxThreadsPerBlock)); + if (BT*2 >= K || 0 == sorted) { + RadixTopK<<>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); + } else if (BT*4>=K) { + RadixTopK<<>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); + } else if (BT*8>=K) { + RadixTopK<<>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); + } else { + RadixTopK<<>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits::lowest(), std::numeric_limits::max()); } + } else { + auto input_key_buffer = kernel->GetScratchBuffer(dimension); + 
auto output_key_buffer = kernel->GetScratchBuffer(dimension); + auto input_value_buffer = kernel->GetScratchBuffer(dimension); + auto output_value_buffer = kernel->GetScratchBuffer(dimension); + auto* input_key = input_key_buffer.get(); + auto* output_key = output_key_buffer.get(); + auto* input_value = input_value_buffer.get(); + auto* output_value = output_value_buffer.get(); + size_t temp_bytes = 0; + CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension)); + auto temp_storage_buffer = kernel->GetScratchBuffer(temp_bytes); + auto* temp_storage = temp_storage_buffer.get(); + auto blocks_per_grid_D = (int)(ceil(static_cast(dimension) / BT)); + auto blocks_per_grid_K = (int)(ceil(static_cast(K) / BT)); + for (int64_t i = 0; i < N; i++) { + FillInput<<>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension); + CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension)); + if (1 == sorted) { + FillOutput<<>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension); + } else { //reorder by ascending index + ExcludeOutput<<>>(output_value, K, dimension); + CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension)); + FillOutput<<>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension); + } + } } return Status::OK(); } @@ -91,4 +440,4 @@ TOPKIMPLE(float); TOPKIMPLE(double); } // namespace cuda -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.h b/onnxruntime/core/providers/cuda/math/topk_impl.h index 
ddb66606c2..e28f98f1b4 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.h +++ b/onnxruntime/core/providers/cuda/math/topk_impl.h @@ -14,4 +14,4 @@ template Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension); } // namespace cuda -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/topk_op_test.cc b/onnxruntime/test/providers/cpu/math/topk_op_test.cc index 32d1f24f24..97b493ce8f 100644 --- a/onnxruntime/test/providers/cpu/math/topk_op_test.cc +++ b/onnxruntime/test/providers/cpu/math/topk_op_test.cc @@ -475,5 +475,50 @@ TEST(TopKOperator, SortedSelection) { RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis, 0); // smallest values } +TEST(TopKOperator, MediumArrayTopKSorted) +{ + std::vector input_vals(1000, 0.0f); + std::iota(input_vals.begin(), input_vals.end(), 0.0f); + std::vector input_dimensions = {1000}; + std::vector expected_vals(100, 0.0f); + std::iota(expected_vals.begin(), expected_vals.end(), 900.0f); + std::reverse(expected_vals.begin(), expected_vals.end()); + std::vector expected_indices(100, 0); + std::iota(expected_indices.begin(), expected_indices.end(), 900); + std::reverse(expected_indices.begin(), expected_indices.end()); + std::vector expected_dimensions = {100}; + RunTest(11, 100, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1); +} + +TEST(TopKOperator, BigArrayTopKSorted) +{ + std::vector input_vals(10000, 0.0f); + std::iota(input_vals.begin(), input_vals.end(), 0.0f); + std::vector input_dimensions = {10000}; + std::vector expected_vals(1000, 0.0f); + std::iota(expected_vals.begin(), expected_vals.end(), 9000.0f); + std::reverse(expected_vals.begin(), expected_vals.end()); + 
std::vector expected_indices(1000, 0); + std::iota(expected_indices.begin(), expected_indices.end(), 9000); + std::reverse(expected_indices.begin(), expected_indices.end()); + std::vector expected_dimensions = {1000}; + RunTest(11, 1000, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1); +} + +TEST(TopKOperator, BigArrayBigTopKSorted) +{ + std::vector input_vals(10000, 0.0f); + std::iota(input_vals.begin(), input_vals.end(), 0.0f); + std::vector input_dimensions = {10000}; + std::vector expected_vals(9000, 0.0f); + std::iota(expected_vals.begin(), expected_vals.end(), 1000.0f); + std::reverse(expected_vals.begin(), expected_vals.end()); + std::vector expected_indices(9000, 0); + std::iota(expected_indices.begin(), expected_indices.end(), 1000); + std::reverse(expected_indices.begin(), expected_indices.end()); + std::vector expected_dimensions = {9000}; + RunTest(11, 9000, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1); +} + } // namespace test } // namespace onnxruntime