mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Rashuai/boost cuda TopK performance (#2826)
* Implement Bitonic and Radix TopK * remove needless print out * fix com err * add negative support * fix comments Co-authored-by: Randy <45701928+RandyShuai@users.noreply.github.com>
This commit is contained in:
parent
08113b80cc
commit
38b34babe0
3 changed files with 419 additions and 25 deletions
|
|
@ -3,12 +3,344 @@
|
|||
|
||||
#include "topk_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "device_atomic_functions.h"
|
||||
#include "cub/cub.cuh"
|
||||
#include "cub/util_type.cuh"
|
||||
#include "cub/util_allocator.cuh"
|
||||
#include "cub/device/device_radix_sort.cuh"
|
||||
#include <limits>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
using namespace cub;
|
||||
|
||||
// Key/value pair held in shared memory by BitonicTopK: `key` is the element
// value being ranked and `val` is the position it was loaded from.
template <typename T>
struct KV {
  T key;        // element value
  int64_t val;  // index of the element along the TopK axis
};
|
||||
|
||||
// Threads per block used by every launch in this file.
#define BT GridDim::maxThreadsPerBlock
// Smallest power of two that is >= N (computed in double precision).
#define ALIGN(N) static_cast<int64_t>(pow(2, ceil(log2(static_cast<double>(N)))))
// Flat offset of the idx-th element along the TopK axis in the input tensor.
// Expects `left_dim`, `mid_dim` and `right_dim` to be in scope at the use site.
#define FROM(idx) (left_dim + (idx) * mid_dim + right_dim)
// Flat offset of the idx-th element along the TopK axis in the output tensor.
// Expects `left_dim`, `mid_dim`, `right_dim`, `K` and `dimension` in scope.
#define TO(idx) (left_dim * K / dimension + (idx) * mid_dim + right_dim)
// Padding key that can never win: lowest value when selecting the largest,
// highest value otherwise. Expects `largest`, `type_min`, `type_max` in scope.
#define TRIVIAL (1 == largest ? type_min : type_max)
// Pair-wise selection/comparison on {key, val} records. Ties on `key` are
// broken on `val`, with the direction depending on `largest` (in scope).
// Arguments are fully parenthesized so that any expression can be passed.
#define BIGGER(n, m) ((n).key > (m).key ? (n) : ((n).key < (m).key ? (m) : ((n).val > (m).val ? (1 == largest ? (m) : (n)) : (1 == largest ? (n) : (m)))))
#define SMALLER(n, m) ((n).key < (m).key ? (n) : ((n).key > (m).key ? (m) : ((n).val < (m).val ? (1 == largest ? (m) : (n)) : (1 == largest ? (n) : (m)))))
#define IS_SMALLER(n, m) ((n).key < (m).key || (!((n).key > (m).key) && (1 == largest ? (n).val > (m).val : (n).val < (m).val)))
// Minimum of two values; the first argument wins on equality.
#define LESS(n, m) ((n) <= (m) ? (n) : (m))
|
||||
|
||||
// TopK for one slice per thread block, performed entirely in shared memory.
//
// Launch layout (as used by TopKImpl in this file): one block per slice,
// GridDim::maxThreadsPerBlock threads, and aligned_dimension * sizeof(KV<T>)
// bytes of dynamic shared memory. Each thread handles one (lo, hi) pair per
// step, so aligned_dimension must not exceed 2 * blockDim.x.
//
//   X, V, I            input values, output values, output indices
//   elem_nums, size    per-axis element counts and the number of axes
//   axis               axis along which TopK is taken
//   K / aligned_K      requested count / K rounded up to a power of two
//   largest            1 selects the largest values, otherwise the smallest
//   sorted             1 emits results ordered by value, otherwise by index
//   dimension          number of elements along `axis`
//   aligned_dimension  `dimension` rounded up to a power of two
//   type_min/type_max  padding keys for slots beyond `dimension`
template <typename T>
__global__ void BitonicTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t aligned_K, int64_t largest, int64_t sorted, int64_t dimension, int64_t aligned_dimension, T type_min, T type_max) {
  const auto lane = threadIdx.x;
  const auto slice = blockIdx.x;
  extern __shared__ char shared_mem[];
  KV<T>* const pairs = reinterpret_cast<KV<T>*>(shared_mem);

  // These three names are consumed by the FROM()/TO() macros.
  const bool on_last_axis = (axis == size - 1);
  const int64_t mid_dim = on_last_axis ? 1 : elem_nums[axis + 1];
  const int64_t left_dim = slice / mid_dim * elem_nums[axis];
  const int64_t right_dim = on_last_axis ? 0 : slice % elem_nums[axis + 1];

  // Stage the slice in shared memory; slots past `dimension` get a padding key.
  for (auto slot = lane; slot < aligned_dimension; slot += blockDim.x) {
    if (slot < dimension) {
      pairs[slot].key = X[FROM(slot)];
    } else {
      pairs[slot].key = TRIVIAL;
    }
    pairs[slot].val = slot;
  }
  __syncthreads();

  // Sort every run of aligned_K entries.
  for (int64_t run = 1; run < aligned_K; run <<= 1) {
    const int64_t span = run << 1;
    for (int64_t stride = run; stride > 0; stride >>= 1) {
      const int64_t lo = (lane << 1) - (lane & (stride - 1));
      const int64_t hi = lo + stride;
      if (hi < aligned_dimension) {
        const bool flipped = (span & lo) == 0;
        if (flipped != IS_SMALLER(pairs[lo], pairs[hi])) {
          const KV<T> held = pairs[lo];
          pairs[lo] = pairs[hi];
          pairs[hi] = held;
        }
      }
      __syncthreads();
    }
    __syncthreads();
  }

  // Merge neighbouring runs, keeping aligned_K winners, then re-sort them.
  for (int64_t run = aligned_K; run < aligned_dimension; run <<= 1) {
    const int64_t span = run << 1;
    const int64_t keep = (lane << 1) - (lane & (run - 1));
    const int64_t drop = keep + run;
    if (keep % span < aligned_K && drop < aligned_dimension) {
      if (1 == largest) {
        pairs[keep] = BIGGER(pairs[keep], pairs[drop]);
      } else {
        pairs[keep] = SMALLER(pairs[keep], pairs[drop]);
      }
    }
    __syncthreads();
    for (int64_t stride = aligned_K >> 1; stride > 0; stride >>= 1) {
      const int64_t lo = (lane << 1) - (lane & (stride - 1));
      const int64_t hi = lo + stride;
      if (lo % span < aligned_K && hi < aligned_dimension) {
        const bool flipped = (span & lo) == 0;
        if (flipped != IS_SMALLER(pairs[lo], pairs[hi])) {
          const KV<T> held = pairs[lo];
          pairs[lo] = pairs[hi];
          pairs[hi] = held;
        }
      }
      __syncthreads();
    }
    __syncthreads();
  }

  // Value-ordered output: `sorted` is uniform across the block, so every
  // thread leaves here together and no later barrier is skipped by a subset.
  if (1 == sorted) {
    int64_t rank = -1;  // output position for this thread, -1 when it has none
    if (1 == largest) {
      // Winners occupy the tail [aligned_K - K, aligned_K); emit them reversed.
      if (lane >= aligned_K - K && lane < aligned_K) {
        rank = aligned_K - 1 - lane;
      }
    } else if (lane < K) {
      rank = lane;
    }
    if (rank >= 0) {
      const auto dst = TO(rank);
      V[dst] = pairs[lane].key;
      I[dst] = pairs[lane].val;
    }
    return;
  }

  // Index-ordered output: give the non-winning slots an index larger than any
  // real one so that they end up behind the winners after the index sort.
  const bool discard = (1 == largest) ? (lane < aligned_K - K)
                                      : (lane >= K && lane < aligned_K);
  if (discard) {
    pairs[lane].val = aligned_dimension;
  }
  __syncthreads();

  // Sort the first aligned_K entries by index, ascending.
  for (int64_t run = 1; run < aligned_K; run <<= 1) {
    const int64_t span = run << 1;
    for (int64_t stride = run; stride > 0; stride >>= 1) {
      const int64_t lo = (lane << 1) - (lane & (stride - 1));
      const int64_t hi = lo + stride;
      if (hi < aligned_K) {
        const bool flipped = (span & lo) == 0;
        if (flipped != (pairs[lo].val < pairs[hi].val)) {
          const KV<T> held = pairs[lo];
          pairs[lo] = pairs[hi];
          pairs[hi] = held;
        }
      }
      __syncthreads();
    }
    __syncthreads();
  }

  if (lane < K) {
    const auto dst = TO(lane);
    V[dst] = pairs[lane].key;
    I[dst] = pairs[lane].val;
  }
}
|
||||
|
||||
// Approximate equality: true when the absolute difference of the two values,
// converted to double, is below 1e-5.
template <typename T>
__device__ __inline__ bool Equal(const T& t0, const T& t1) {
  const auto gap = (t0 > t1) ? (t0 - t1) : (t1 - t0);
  return static_cast<double>(gap) < 1.0e-5;
}
|
||||
|
||||
// True when *t0 and *t1 agree on every bit at position `skip` and above,
// i.e. their XOR shifted right by `skip` is zero.
template <typename T>
__device__ bool SamePrefix(const T* t0, const T* t1, int64_t skip) {
  const auto differing_bits = (*t0) ^ (*t1);
  return (differing_bits >> skip) == 0;
}

// float overload: compares the raw 32-bit patterns.
__device__ bool SamePrefix(const float* f0, const float* f1, int64_t skip) {
  const int32_t* bits0 = reinterpret_cast<const int32_t*>(f0);
  const int32_t* bits1 = reinterpret_cast<const int32_t*>(f1);
  return SamePrefix(bits0, bits1, skip);
}

// double overload: compares the raw 64-bit patterns.
__device__ bool SamePrefix(const double* d0, const double* d1, int64_t skip) {
  const int64_t* bits0 = reinterpret_cast<const int64_t*>(d0);
  const int64_t* bits1 = reinterpret_cast<const int64_t*>(d1);
  return SamePrefix(bits0, bits1, skip);
}
|
||||
|
||||
// Returns the 8-bit digit of *t that starts at bit `skip`.
template <typename T>
__device__ int32_t Radix(const T* t, int64_t skip) {
  return ((*t) >> skip) & 255;
}

// float overload: extracts the digit from the raw 32-bit pattern.
__device__ int32_t Radix(const float* f, int64_t skip) {
  return Radix((const int32_t*)f, skip);
}

// double overload: extracts the digit from the raw 64-bit pattern.
// The pointer must be reinterpreted as int64_t so the call dispatches to the
// integer template; casting to `const double*` would select this overload
// again and recurse without end.
__device__ int32_t Radix(const double* d, int64_t skip) {
  return Radix((const int64_t*)d, skip);
}
|
||||
|
||||
// ORs `byte` (an already-shifted bit mask) into *t.
template <typename T>
__device__ void SetByte(T* t, int64_t byte) {
  *t = static_cast<T>(*t | byte);
}

// float overload: ORs into the raw 32-bit pattern.
__device__ void SetByte(float* f, int64_t byte) {
  int32_t* bits = reinterpret_cast<int32_t*>(f);
  SetByte(bits, byte);
}

// double overload: ORs into the raw 64-bit pattern.
__device__ void SetByte(double* d, int64_t byte) {
  int64_t* bits = reinterpret_cast<int64_t*>(d);
  SetByte(bits, byte);
}
|
||||
|
||||
// TopK for one slice per thread block using a byte-wise radix selection of the
// K-th value followed by a compaction pass (and an optional block-wide sort).
//
// Launch layout (as used by TopKImpl in this file): one block per slice,
// THREADS threads, and 256 * sizeof(uint32_t) bytes of dynamic shared memory
// for the digit histogram `H` (H[0] and H[1] also serve as broadcast slots).
// The histogram reset uses `tid < 256`, so the block needs at least 256 threads.
//
//   X, V, I            input values, output values, output indices
//   elem_nums, size    per-axis element counts and the number of axes
//   axis               axis along which TopK is taken
//   K                  requested count; when sorted == 1 the final sort holds
//                      THREADS * KPT items, so K must not exceed that
//   largest            1 selects the largest values, 0 the smallest
//   sorted             1 sorts the K results by value before returning
//   dimension          number of elements along `axis`
//   XPT                not referenced by the kernel body
//   type_min/type_max  padding keys for the final sort
//
// Every CUB collective below shares one `temp_storage` union. CUB requires a
// block barrier before temp storage is reused, so each pair of consecutive
// collectives is separated by __syncthreads().
template <typename T, int64_t THREADS, int64_t KPT>
__global__ void RadixTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t dimension, int64_t XPT, T type_min, T type_max) {
  auto tid = threadIdx.x;
  auto bid = blockIdx.x;
  extern __shared__ char shared_mem[];
  auto H = (uint32_t*)shared_mem;
  // left_dim / mid_dim / right_dim are consumed by the FROM()/TO() macros.
  auto mid_dim = axis == size - 1 ? 1 : elem_nums[axis + 1];
  auto left_dim = bid / mid_dim * elem_nums[axis];
  auto right_dim = axis == size - 1 ? 0 : bid % elem_nums[axis + 1];
  T Kth = (T)0, sign = (T)1;
  typedef BlockScan<uint32_t, THREADS> BlockScan;
  typedef BlockReduce<uint32_t, THREADS> BlockReduce;
  typedef BlockRadixSort<T, THREADS, KPT, int64_t> BlockRadixSort;
  __shared__ union {
    typename BlockScan::TempStorage scan;
    typename BlockReduce::TempStorage reduce;
    typename BlockRadixSort::TempStorage sort;
  } temp_storage;

  // Count the strictly positive and strictly negative elements of the slice.
  uint32_t positive = 0, negative = 0;
  for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
    T x = X[FROM(x_i)];
    if (x > 0) {
      ++positive;
    } else if (x < 0) {
      ++negative;
    }
  }
  __syncthreads();
  positive = BlockReduce(temp_storage.reduce).Sum(positive);
  // temp_storage.reduce is about to be reused: wait until the first reduction
  // has been fully consumed by every warp.
  __syncthreads();
  negative = BlockReduce(temp_storage.reduce).Sum(negative);
  // Only thread 0 holds the valid reduction result; broadcast through H.
  if (0 == tid) {
    H[0] = positive;
    H[1] = negative;
  }
  __syncthreads();
  positive = H[0];
  negative = H[1];

  // Radix-select the K-th value when it is known to be non-zero. Otherwise
  // Kth keeps its initial value of 0. The condition is uniform across the
  // block, so the barriers inside the branch are reached by all threads.
  if ((1 == largest && (K <= positive || dimension - K + 1 <= negative)) ||
      (0 == largest && (K <= negative || dimension - K + 1 <= positive))) {
    auto KK = K;
    // The histogram below only counts values with sign * x > 0. Pick the sign
    // and the rank (counted from the large-magnitude end) accordingly.
    if (1 == largest) {
      if (KK > positive) {
        KK = dimension - KK + 1;
        sign = (T)-1;
      }
    } else {
      if (KK > negative) {
        KK = dimension - KK + 1;
      } else {
        sign = (T)-1;
      }
    }
    __syncthreads();
#pragma unroll
    for (int64_t byte = sizeof(T) - 1; byte > -1; --byte) {
      if (tid < 256) H[tid] = 0;
      __syncthreads();
      auto skip = 8 * byte, prev_skip = 8 * (byte + 1);
      for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
        T x = sign * X[FROM(x_i)];
        // For the top byte every candidate counts; afterwards only those whose
        // already-resolved high bytes match Kth (the short-circuit also keeps
        // SamePrefix from shifting by the full bit width).
        if (x > 0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) {
          atomicAdd(&H[Radix(&x, skip)], 1);
        }
      }
      __syncthreads();
      // Every thread walks the histogram identically, so KK and Kth stay in
      // step across the block without a broadcast.
      for (int64_t radix = 255; radix > 0; --radix) {
        if (H[radix] < KK) {
          KK -= H[radix];
        } else {
          SetByte(&Kth, radix << skip);
          break;
        }
      }
      __syncthreads();
    }
    Kth *= sign;
  }

  // Per-thread counts of elements strictly better than Kth and equal to Kth.
  uint32_t superior = 0, equal = 0;
  for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
    auto x = X[FROM(x_i)];
    if ((1 == largest && x > Kth) || (0 == largest && x < Kth)) {
      ++superior;
    } else if (Equal(x, Kth)) {
      ++equal;
    }
  }
  __syncthreads();
  auto all_superior = superior;
  all_superior = BlockReduce(temp_storage.reduce).Sum(all_superior);
  if (0 == tid) {
    H[0] = all_superior;
  }
  __syncthreads();
  all_superior = H[0];
  BlockScan(temp_storage.scan).ExclusiveSum(superior, superior);
  // temp_storage.scan is about to be reused: barrier between the two scans.
  __syncthreads();
  BlockScan(temp_storage.scan).ExclusiveSum(equal, equal);
  __syncthreads();

  // After the scans `superior` / `equal` hold the counts owned by lower
  // threads. They give this thread's first output slot and how many
  // "equal" elements it may still emit.
  auto equal_quota = K - all_superior - equal;
  auto output_i = superior + LESS(K - all_superior, equal);
  for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
    auto x = X[FROM(x_i)];
    if ((1 == largest && x > Kth) || (0 == largest && x < Kth)) {
      auto to_i = TO(output_i);
      V[to_i] = x;
      I[to_i] = x_i;
      ++output_i;
    } else if (Equal(x, Kth) && equal_quota > 0) {
      auto to_i = TO(output_i);
      V[to_i] = x;
      I[to_i] = x_i;
      ++output_i;
      --equal_quota;
    }
  }
  __syncthreads();

  // Optional final ordering of the K selected pairs by value.
  if (1 == sorted) {
    T keys[KPT];
    int64_t vals[KPT];
    for (int64_t k_i = tid, k_c = 0; k_c < KPT; k_i += blockDim.x, ++k_c) {
      if (k_i < K) {
        auto to_i = TO(k_i);
        keys[k_c] = V[to_i];
        vals[k_c] = I[to_i];
      } else {
        // Pad with a key that sorts behind every real entry.
        if (1 == largest) {
          keys[k_c] = type_min;
        } else {
          keys[k_c] = type_max;
        }
      }
    }
    __syncthreads();
    if (1 == largest) {
      BlockRadixSort(temp_storage.sort).SortDescending(keys, vals);
    } else {
      BlockRadixSort(temp_storage.sort).Sort(keys, vals);
    }
    __syncthreads();
#pragma unroll
    for (int64_t k_c = 0; k_c < KPT; ++k_c) {
      auto k_i = tid * KPT + k_c;
      if (k_i < K) {
        auto to_i = TO(k_i);
        V[to_i] = keys[k_c];
        I[to_i] = vals[k_c];
      }
    }
  }
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
|
||||
|
|
@ -38,30 +370,47 @@ __global__ void ExcludeOutput(int64_t* output_i, int64_t K, int64_t dimension) {
|
|||
|
||||
template <typename T>
|
||||
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
|
||||
auto input_key_buffer = kernel->GetScratchBuffer<T>(dimension);
|
||||
auto output_key_buffer = kernel->GetScratchBuffer<T>(dimension);
|
||||
auto input_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
|
||||
auto output_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
|
||||
auto input_key = input_key_buffer.get();
|
||||
auto output_key = output_key_buffer.get();
|
||||
auto input_value = input_value_buffer.get();
|
||||
auto output_value = output_value_buffer.get();
|
||||
size_t temp_bytes = 0;
|
||||
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension));
|
||||
auto temp_storage_buffer = kernel->GetScratchBuffer<char>(temp_bytes);
|
||||
auto temp_storage = temp_storage_buffer.get();
|
||||
auto blocksPerGridD = (int)(ceil(static_cast<float>(dimension) / GridDim::maxThreadsPerBlock));
|
||||
auto blocksPerGridK = (int)(ceil(static_cast<float>(K) / GridDim::maxThreadsPerBlock));
|
||||
for (int64_t i = 0; i < N; i++) {
|
||||
FillInput<T><<<blocksPerGridD, GridDim::maxThreadsPerBlock, 0>>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension);
|
||||
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension));
|
||||
if (1 == sorted) {
|
||||
FillOutput<T><<<blocksPerGridK, GridDim::maxThreadsPerBlock, 0>>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
|
||||
} else { //reorder by ascending index
|
||||
ExcludeOutput<<<blocksPerGridD, GridDim::maxThreadsPerBlock, 0>>>(output_value, K, dimension);
|
||||
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension));
|
||||
FillOutput<T><<<blocksPerGridK, GridDim::maxThreadsPerBlock, 0>>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
|
||||
auto aligned_K = ALIGN(K);
|
||||
auto aligned_dimension = ALIGN(dimension);
|
||||
if (aligned_dimension <= GridDim::maxThreadsPerBlock << 1) {
|
||||
BitonicTopK<T><<<N, GridDim::maxThreadsPerBlock, aligned_dimension * sizeof(KV<T>)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, aligned_dimension, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
|
||||
} else if (K <= BT*16 || 0 == sorted) {
|
||||
auto XPT = static_cast<int64_t>(ceil(static_cast<double>(dimension) / GridDim::maxThreadsPerBlock));
|
||||
if (BT*2 >= K || 0 == sorted) {
|
||||
RadixTopK<T,BT,2><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
|
||||
} else if (BT*4>=K) {
|
||||
RadixTopK<T,BT,4><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
|
||||
} else if (BT*8>=K) {
|
||||
RadixTopK<T,BT,8><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
|
||||
} else {
|
||||
RadixTopK<T,BT,16><<<N,BT,256*sizeof(uint32_t)>>>(input_x, output_v, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
|
||||
}
|
||||
} else {
|
||||
auto input_key_buffer = kernel->GetScratchBuffer<T>(dimension);
|
||||
auto output_key_buffer = kernel->GetScratchBuffer<T>(dimension);
|
||||
auto input_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
|
||||
auto output_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
|
||||
auto* input_key = input_key_buffer.get();
|
||||
auto* output_key = output_key_buffer.get();
|
||||
auto* input_value = input_value_buffer.get();
|
||||
auto* output_value = output_value_buffer.get();
|
||||
size_t temp_bytes = 0;
|
||||
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension));
|
||||
auto temp_storage_buffer = kernel->GetScratchBuffer<char>(temp_bytes);
|
||||
auto* temp_storage = temp_storage_buffer.get();
|
||||
auto blocks_per_grid_D = (int)(ceil(static_cast<float>(dimension) / BT));
|
||||
auto blocks_per_grid_K = (int)(ceil(static_cast<float>(K) / BT));
|
||||
for (int64_t i = 0; i < N; i++) {
|
||||
FillInput<T><<<blocks_per_grid_D, BT, 0>>>(input_x, input_key, input_value, elem_nums, size, axis, K, i, dimension);
|
||||
CUDA_RETURN_IF_ERROR(1 == largest ? cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension) : cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension));
|
||||
if (1 == sorted) {
|
||||
FillOutput<T><<<blocks_per_grid_K, BT, 0>>>(output_key, output_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
|
||||
} else { //reorder by ascending index
|
||||
ExcludeOutput<<<blocks_per_grid_D, BT, 0>>>(output_value, K, dimension);
|
||||
CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension));
|
||||
FillOutput<T><<<blocks_per_grid_K, BT, 0>>>(input_key, input_value, output_v, output_i, elem_nums, size, axis, K, i, dimension);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -91,4 +440,4 @@ TOPKIMPLE(float);
|
|||
TOPKIMPLE(double);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -14,4 +14,4 @@ template <typename T>
|
|||
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -475,5 +475,50 @@ TEST(TopKOperator, SortedSelection) {
|
|||
RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis, 0); // smallest values
|
||||
}
|
||||
|
||||
// Sorted top-100 of the ramp 0..999: expects values and indices 999 down to 900.
TEST(TopKOperator, MediumArrayTopKSorted)
{
  std::vector<int64_t> input_dimensions = {1000};
  std::vector<float> input_vals(1000, 0.0f);
  std::iota(input_vals.begin(), input_vals.end(), 0.0f);

  // Filling through reverse iterators produces the descending sequence directly.
  std::vector<int64_t> expected_dimensions = {100};
  std::vector<float> expected_vals(100, 0.0f);
  std::iota(expected_vals.rbegin(), expected_vals.rend(), 900.0f);
  std::vector<int64_t> expected_indices(100, 0);
  std::iota(expected_indices.rbegin(), expected_indices.rend(), 900);

  RunTest(11, 100, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1);
}
|
||||
|
||||
// Sorted top-1000 of the ramp 0..9999: expects values and indices 9999 down to 9000.
TEST(TopKOperator, BigArrayTopKSorted)
{
  std::vector<int64_t> input_dimensions = {10000};
  std::vector<float> input_vals(10000, 0.0f);
  std::iota(input_vals.begin(), input_vals.end(), 0.0f);

  // Filling through reverse iterators produces the descending sequence directly.
  std::vector<int64_t> expected_dimensions = {1000};
  std::vector<float> expected_vals(1000, 0.0f);
  std::iota(expected_vals.rbegin(), expected_vals.rend(), 9000.0f);
  std::vector<int64_t> expected_indices(1000, 0);
  std::iota(expected_indices.rbegin(), expected_indices.rend(), 9000);

  RunTest(11, 1000, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1);
}
|
||||
|
||||
// Sorted top-9000 of the ramp 0..9999: expects values and indices 9999 down to 1000.
TEST(TopKOperator, BigArrayBigTopKSorted)
{
  std::vector<int64_t> input_dimensions = {10000};
  std::vector<float> input_vals(10000, 0.0f);
  std::iota(input_vals.begin(), input_vals.end(), 0.0f);

  // Filling through reverse iterators produces the descending sequence directly.
  std::vector<int64_t> expected_dimensions = {9000};
  std::vector<float> expected_vals(9000, 0.0f);
  std::iota(expected_vals.rbegin(), expected_vals.rend(), 1000.0f);
  std::vector<int64_t> expected_indices(9000, 0);
  std::iota(expected_indices.rbegin(), expected_indices.rend(), 1000);

  RunTest(11, 9000, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1);
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue