mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Improve TopK performance. (#3612)
* Update TopK implementation.
- add faster heap
- special case k=1
- update selector for when to use heap and when to use nth_element based on performance testing
- parallelize if enough work to do
- reduce templatized code
- add some extra unit tests.
Perf tested vs. master. Average speedup is 3.75x using this combination of input sizes:
```
batches = [10, 25, 50]
batch_size = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
k = [1, 2, 4, 6, 8, 16, 24, 32, 48, 64, 128]
```
For larger batches (e.g. 50x2048) the speedup is over 20x.
This commit is contained in:
parent
9636da3951
commit
b4508dbdc6
2 changed files with 327 additions and 154 deletions
|
|
@ -20,6 +20,7 @@
|
|||
#include "core/common/exceptions.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include <queue>
|
||||
#include <algorithm>
|
||||
|
|
@ -31,145 +32,301 @@ namespace onnxruntime {
|
|||
template <typename T>
|
||||
struct GreaterValueCmp {
|
||||
using DataType = T;
|
||||
bool operator()(const pair<T, int64_t>& lhs, const pair<T, int64_t>& rhs) {
|
||||
return (lhs.first > rhs.first ||
|
||||
GreaterValueCmp(const T* data = nullptr) : data_(data) {
|
||||
}
|
||||
|
||||
bool operator()(const int64_t lhs_idx, const int64_t rhs_idx) const {
|
||||
return (data_[lhs_idx] > data_[rhs_idx] ||
|
||||
// when values are equal, we want lhs to get higher "priority"
|
||||
// if its corresponding index comes first (i.e.) is lower
|
||||
(lhs.first == rhs.first && lhs.second < rhs.second));
|
||||
(data_[lhs_idx] == data_[rhs_idx] && lhs_idx < rhs_idx));
|
||||
}
|
||||
|
||||
bool CompareValueOnly(const T& lhs, const T& rhs) const {
|
||||
return lhs > rhs;
|
||||
}
|
||||
|
||||
private:
|
||||
const T* data_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LesserValueCmp {
|
||||
using DataType = T;
|
||||
bool operator()(const pair<T, int64_t>& lhs, const pair<T, int64_t>& rhs) {
|
||||
return (lhs.first < rhs.first ||
|
||||
|
||||
LesserValueCmp(const T* data = nullptr) : data_(data) {
|
||||
}
|
||||
|
||||
bool operator()(const int64_t lhs_idx, const int64_t rhs_idx) const {
|
||||
return (data_[lhs_idx] < data_[rhs_idx] ||
|
||||
// when values are equal, we want lhs to get higher "priority"
|
||||
// if its corresponding index comes first (i.e.) is lower
|
||||
(lhs.first == rhs.first && lhs.second < rhs.second));
|
||||
(data_[lhs_idx] == data_[rhs_idx] && lhs_idx < rhs_idx));
|
||||
}
|
||||
|
||||
bool CompareValueOnly(const T& lhs, const T& rhs) const {
|
||||
return lhs < rhs;
|
||||
}
|
||||
|
||||
private:
|
||||
const T* data_;
|
||||
};
|
||||
|
||||
/*
|
||||
Maintain a binary heap where HeapComp of the parent with either child is false.
|
||||
e.g. if the comparison is 'greater than', the parent is smaller than both children.
|
||||
There is no ordering within a level.
|
||||
|
||||
NOTE: The comparison is backwards compared to std::priority_queue as we use the same comparator for this as for
|
||||
nth_element in SelectTopK. As such for a heap selecting the largest values the comparator is 'greater than'.
|
||||
*/
|
||||
template <class HeapCmp>
|
||||
static void HeapifyIthPosition(int64_t* heap, size_t i, size_t k, const HeapCmp& heap_cmp) {
|
||||
while (true) {
|
||||
size_t left = 2 * i + 1;
|
||||
size_t right = left + 1;
|
||||
if (right < k) {
|
||||
// need to check both left and right children as either could be replaced
|
||||
|
||||
// check if we should move child up. check left node as well as whether left is preferred over right.
|
||||
// if 'i' can replace left, check whether right would replace left (if so, i replaces left as it's the weakest)
|
||||
bool i_replaces_left = heap_cmp(heap[i], heap[left]);
|
||||
if (i_replaces_left && heap_cmp(heap[right], heap[left])) {
|
||||
// left is going to be pushed up as both i and right beat it
|
||||
// NOTE: std::swap is slower as it uses std::move
|
||||
auto tmp = heap[i];
|
||||
heap[i] = heap[left];
|
||||
heap[left] = tmp;
|
||||
i = left;
|
||||
} else if (i_replaces_left || heap_cmp(heap[i], heap[right])) {
|
||||
// i_replaces_left implies left replaces right due to 'if' so replace right with i as right is the weakest.
|
||||
// also check if i only beats right
|
||||
auto tmp = heap[i];
|
||||
heap[i] = heap[right];
|
||||
heap[right] = tmp;
|
||||
i = right;
|
||||
} else
|
||||
break;
|
||||
} else if ((left < k) && heap_cmp(heap[i], heap[left])) {
|
||||
auto tmp = heap[i];
|
||||
heap[i] = heap[left];
|
||||
heap[left] = tmp;
|
||||
i = left;
|
||||
} else
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Static helpers that implement the core logic for each of the 'TopK' operator flavor
|
||||
|
||||
// Selects the top k elements (largest or smallest based on template parameter)
|
||||
template <class Comparator>
|
||||
static vector<pair<typename Comparator::DataType, int64_t>> select_top_k(
|
||||
const ConstEigenMatrixMapRowMajor<typename Comparator::DataType>& raw_data, int64_t row_num, int64_t num_blocks,
|
||||
int64_t block_slice, int64_t inter_block_offset, const unsigned k,
|
||||
bool sort_top_k) {
|
||||
// create a data holder and insert elements
|
||||
vector<pair<typename Comparator::DataType, int64_t>> data_holder;
|
||||
data_holder.reserve(num_blocks);
|
||||
static void SelectTopK(const Comparator& comparer,
|
||||
int64_t row_offset, int64_t num_blocks, int64_t block_slice, int64_t inter_block_offset,
|
||||
const unsigned k, bool sort_top_k, vector<int64_t>& data_holder) {
|
||||
for (int64_t l = 0; l < num_blocks; ++l) {
|
||||
data_holder.push_back({raw_data(row_num, l * block_slice + inter_block_offset), l});
|
||||
data_holder[l] = (row_offset + (l * block_slice + inter_block_offset));
|
||||
}
|
||||
|
||||
// find the top k (largest or smallest) elements in the data holder - O(n)
|
||||
nth_element(data_holder.begin(), data_holder.begin() + (k - 1), data_holder.end(), Comparator());
|
||||
// find the top k (largest or smallest) elements in the data holder - O(n) average. O(n*n) worst case.
|
||||
// See https://en.wikipedia.org/wiki/Quickselect
|
||||
nth_element(data_holder.begin(), data_holder.begin() + (k - 1), data_holder.end(), comparer);
|
||||
|
||||
// sort the top k elements if needed - O (k log k)
|
||||
if (sort_top_k) {
|
||||
std::sort(data_holder.begin(), data_holder.begin() + k, Comparator());
|
||||
std::sort(data_holder.begin(), data_holder.begin() + k, comparer);
|
||||
}
|
||||
|
||||
// the data_holder now contains the top k elements in the first k indices
|
||||
return data_holder;
|
||||
// the data_holder now contains the indices of the top k elements in the first k elements
|
||||
}
|
||||
|
||||
// Given an input tensor 'input' and metadata values - 'k' and 'axis_parsed',
|
||||
// this method will extract the sorted top k largest/smallest elements and place them in the output tensor 'values'
|
||||
// along with the metadata output 'indices'
|
||||
template <bool largest, bool sorted, class Comparator>
|
||||
static void extract_top_k_elements(const Tensor* input, const TensorShape& input_shape, Tensor* values,
|
||||
Tensor* indices, const TensorShape& output_shape, const unsigned k,
|
||||
const unsigned axis_parsed) {
|
||||
template <class Comparator>
|
||||
static void FindTopKElements(const Tensor* input, const TensorShape& input_shape, Tensor* values,
|
||||
Tensor* indices, const TensorShape& output_shape, const unsigned k, bool sorted,
|
||||
const unsigned axis_parsed, concurrency::ThreadPool* threadpool) {
|
||||
// Cache some values that will be used in the implementation below
|
||||
const int64_t rows = input_shape.SizeToDimension(static_cast<size_t>(axis_parsed));
|
||||
const int64_t cols = input->Shape().Size() / rows;
|
||||
auto input_map =
|
||||
ConstEigenMatrixMapRowMajor<typename Comparator::DataType>(
|
||||
static_cast<const typename Comparator::DataType*>(input->template Data<typename Comparator::DataType>()), rows, cols);
|
||||
const auto* input_data = input->template Data<typename Comparator::DataType>();
|
||||
|
||||
// Use Eigen maps to allow indexing into the 2d tensors like Values_map(i,j)
|
||||
// Use Eigen maps for convenient indexing into the 2d tensors like Values_map(i,j)
|
||||
const int64_t reduced_cols = output_shape.SizeFromDimension(static_cast<size_t>(axis_parsed));
|
||||
auto values_map = EigenMatrixMapRowMajor<typename Comparator::DataType>(
|
||||
values->template MutableData<typename Comparator::DataType>(), rows, reduced_cols);
|
||||
auto indices_map = EigenMatrixMapRowMajor<int64_t>(indices->template MutableData<int64_t>(), rows, reduced_cols);
|
||||
|
||||
auto* values_data = values->template MutableData<typename Comparator::DataType>();
|
||||
auto* indices_data = indices->template MutableData<int64_t>();
|
||||
auto values_map = EigenMatrixMapRowMajor<typename Comparator::DataType>(values_data, rows, reduced_cols);
|
||||
auto indices_map = EigenMatrixMapRowMajor<int64_t>(indices_data, rows, reduced_cols);
|
||||
|
||||
// This is basically the number of elements within each of the "k" rows
|
||||
const int64_t block_slice = reduced_cols / k;
|
||||
const int64_t num_blocks = input_shape[axis_parsed];
|
||||
const int64_t block_slice = reduced_cols / k;
|
||||
|
||||
for (int64_t i = 0; i < rows; ++i) {
|
||||
for (int64_t j = 0; j < block_slice; ++j) {
|
||||
// Since sorted == true, we will use a Heap to hold the top K values in sorted fashion
|
||||
if (sorted) { // The optimizer will clean-up the redundant condition based on the template parameter 'sorted'
|
||||
auto n_casted = static_cast<double>(num_blocks);
|
||||
auto k_casted = static_cast<double>(k);
|
||||
if ((n_casted + k_casted * log(k_casted)) < (n_casted * log(k_casted))) {
|
||||
// Select first - O(n), then sort O(k * ln(k))
|
||||
// Overall complexity = O (n + k * ln(k))
|
||||
const auto& data_holder = select_top_k<Comparator>(input_map, i, num_blocks, block_slice, j, k, true);
|
||||
for (int64_t l = 0; l < k; ++l) {
|
||||
const auto& elem = data_holder[l];
|
||||
auto col_index = l * block_slice + j;
|
||||
values_map(i, col_index) = elem.first;
|
||||
indices_map(i, col_index) = elem.second;
|
||||
}
|
||||
} else {
|
||||
// Perform sorted selection by passing 'n' elements over a heap of size 'k'
|
||||
// overall complexity = O (n * ln(k))
|
||||
int64_t tp_threads = threadpool != nullptr ? threadpool->NumThreads() : 1;
|
||||
int64_t num_threads = std::min(tp_threads, rows); // split on rows so can't have more threads than rows
|
||||
|
||||
// Build a min-heap/max-heap, the heap element is pair of (value, idx)
|
||||
// The top of the heap is the smallest/largest value depending on whether it is a min-heap/max-heap
|
||||
// This is a min-heap if largest == true, this is a max-heap if largest == false
|
||||
priority_queue<pair<typename Comparator::DataType, int64_t>, vector<pair<typename Comparator::DataType, int64_t>>, Comparator> heap;
|
||||
// rough attempt to make sure there's enough work for each thread. if there's insufficient work the usage of
|
||||
// too many threads degrades performance.
|
||||
// TODO: May want a different calculation for each branch below instead.
|
||||
int64_t threads_needed = static_cast<int64_t>(std::floor(input_shape.Size() * k / (128 * 1024)));
|
||||
num_threads = std::max(std::min(threads_needed, num_threads), static_cast<int64_t>(1));
|
||||
|
||||
// Maintain the size of heap to be less or equal to k, so the
|
||||
// heap will hold the k largest/smallest values
|
||||
for (int64_t l = 0; l < num_blocks; ++l) {
|
||||
const auto value = input_map(i, l * block_slice + j);
|
||||
// largest == true: insert into the min-heap if the size is < k or if the new
|
||||
// element is greater than the min element in the min-heap
|
||||
// from testing various batch sizes relative to k, the following appears to work well as a selector.
|
||||
// tested with following combinations
|
||||
// batch_size = [ 8, 16, 32, 64, 128, 256, 512, 1024, 2048 ]
|
||||
// k = [ 1, 2, 4, 6, 8, 16, 24, 32, 48, 64, 128 ]
|
||||
bool use_priority_queue = k != 1 && (k < 4 || (std::log2(k) / std::log2(num_blocks)) < 0.725);
|
||||
|
||||
// largest == false: insert into the min-heap if the size is < k or if the new
|
||||
// element is lesser than the max element in the max-heap
|
||||
if ((heap.size() < k) || (largest && value > heap.top().first) ||
|
||||
(!largest && value < heap.top().first)) { // the optimizer will clean-up the redundant condition based
|
||||
// on the template parameter 'largest'
|
||||
heap.push({value, l});
|
||||
}
|
||||
if (heap.size() > k) {
|
||||
heap.pop();
|
||||
std::function<void(std::ptrdiff_t batch)> find_top_k;
|
||||
|
||||
if (k == 1) {
|
||||
// just need to compare values and not indexes as the first instance of the best value is always selected
|
||||
find_top_k =
|
||||
[num_threads, rows, block_slice, num_blocks, input_data, cols, &values_map, &indices_map](std::ptrdiff_t batch) {
|
||||
int64_t start_row = static_cast<int64_t>(batch * rows / num_threads);
|
||||
int64_t end_row = static_cast<int64_t>((batch + 1) * rows / num_threads);
|
||||
|
||||
Comparator comparer(input_data);
|
||||
|
||||
for (int64_t i = start_row; i < end_row; ++i) {
|
||||
auto row_offset = i * cols;
|
||||
for (int64_t j = 0; j < block_slice; ++j) {
|
||||
int64_t cur_idx = row_offset + j;
|
||||
|
||||
const auto* cur_value = input_data + cur_idx; // using pointer to data is faster than input_data[cur_idx]
|
||||
auto best = *cur_value; // save best value so we only have one load in the CompareValueOnly call
|
||||
int64_t top_idx = cur_idx;
|
||||
|
||||
for (int64_t l = 1; l < num_blocks; ++l) {
|
||||
cur_value += block_slice;
|
||||
if (comparer.CompareValueOnly(*cur_value, best)) {
|
||||
best = *cur_value;
|
||||
top_idx = cur_value - input_data;
|
||||
}
|
||||
}
|
||||
|
||||
values_map(i, j) = best;
|
||||
// convert overall index to result index
|
||||
// avoid '/' if possible for perf reasons
|
||||
indices_map(i, j) = block_slice == 1 ? (top_idx - row_offset - j)
|
||||
: (top_idx - row_offset - j) / block_slice;
|
||||
}
|
||||
}
|
||||
// Extract these k elements and place them in the results placeholder
|
||||
for (int64_t l = 0; l < k; ++l) {
|
||||
const auto& elem = heap.top();
|
||||
auto col_index = (k - l - 1) * block_slice + j;
|
||||
values_map(i, col_index) = elem.first;
|
||||
indices_map(i, col_index) = elem.second;
|
||||
heap.pop();
|
||||
};
|
||||
} else if (use_priority_queue) {
|
||||
find_top_k =
|
||||
[num_threads, rows, block_slice, num_blocks, k, sorted,
|
||||
input_data, cols, &values_map, &indices_map](std::ptrdiff_t batch) {
|
||||
int64_t start_row = static_cast<int64_t>(batch * rows / num_threads);
|
||||
int64_t end_row = static_cast<int64_t>((batch + 1) * rows / num_threads);
|
||||
|
||||
Comparator comparer(input_data);
|
||||
|
||||
// the heap is stored in indices_data. each iteration overwrites the old data when it adds the
|
||||
// initial k values, so we don't need to clear it.
|
||||
std::vector<int64_t> indices_data(k);
|
||||
int64_t* indices = indices_data.data(); // raw pointer is slightly faster for HeapifyIthPosition
|
||||
|
||||
for (int64_t i = start_row; i < end_row; ++i) {
|
||||
const auto row_offset = i * cols;
|
||||
|
||||
for (int64_t j = 0; j < block_slice; ++j) {
|
||||
int64_t l = 0;
|
||||
auto cur_idx = row_offset + j;
|
||||
|
||||
// add first k items starting from the bottom up
|
||||
for (; l < k; ++l) {
|
||||
indices[k - l - 1] = cur_idx;
|
||||
HeapifyIthPosition(indices, k - l - 1, k, comparer);
|
||||
|
||||
cur_idx += block_slice;
|
||||
}
|
||||
|
||||
// insert remainder if the next value would replace the top of the heap (current worst top k value)
|
||||
// save top so we only have one load in the CompareValueOnly call
|
||||
auto top = input_data[indices[0]];
|
||||
for (; l < num_blocks; ++l) {
|
||||
// we can compare value only. if the current value is equal to the top of the heap it won't
|
||||
// replace it as the index will be higher.
|
||||
if (comparer.CompareValueOnly(input_data[cur_idx], top)) {
|
||||
indices[0] = cur_idx;
|
||||
HeapifyIthPosition(indices, 0, k, comparer);
|
||||
top = input_data[indices[0]];
|
||||
}
|
||||
|
||||
cur_idx += block_slice;
|
||||
}
|
||||
|
||||
if (sorted) {
|
||||
// Extract these k elements and place them in the results placeholder
|
||||
for (l = 0; l < k; ++l) {
|
||||
auto idx = indices[0];
|
||||
auto col_index = (k - l - 1) * block_slice + j;
|
||||
values_map(i, col_index) = input_data[idx];
|
||||
// convert overall index to result index. avoid '/' if possible for perf reasons
|
||||
indices_map(i, col_index) = block_slice == 1 ? (idx - row_offset - j)
|
||||
: (idx - row_offset - j) / block_slice;
|
||||
|
||||
// put the last value at the top of the heap to replace the removed one, and push it into
|
||||
// place in a heap one smaller.
|
||||
indices[0] = indices[k - l - 1];
|
||||
HeapifyIthPosition(indices, 0, k - l - 1, comparer);
|
||||
}
|
||||
} else {
|
||||
for (l = 0; l < k; ++l) {
|
||||
int64_t idx = indices[l];
|
||||
auto col_index = l * block_slice + j;
|
||||
values_map(i, col_index) = input_data[idx];
|
||||
// convert overall index to result index. avoid '/' if possible for perf reasons
|
||||
indices_map(i, col_index) = block_slice == 1 ? (idx - row_offset - j)
|
||||
: (idx - row_offset - j) / block_slice;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // sorted == false
|
||||
// The optimizer will clean-up the redundant condition based on the template parameter 'sorted'
|
||||
};
|
||||
} else {
|
||||
find_top_k =
|
||||
[num_threads, rows, block_slice, num_blocks, k, sorted,
|
||||
input_data, cols,
|
||||
&values_map, &indices_map](std::ptrdiff_t batch) {
|
||||
int64_t start_row = static_cast<int64_t>(batch * rows / num_threads);
|
||||
int64_t end_row = static_cast<int64_t>((batch + 1) * rows / num_threads);
|
||||
|
||||
// If the top K values are not required to be sorted, we use a more optimal selection algorithm
|
||||
// Average - O(n). Worst - O(n * ln(n)) or O(n^2) depending on the implementation, where 'n' is the number of input
|
||||
Comparator comparer(input_data);
|
||||
|
||||
const auto& data_holder = select_top_k<Comparator>(input_map, i, num_blocks, block_slice, j, k, false);
|
||||
// we re-use a single data_holder for performance. avoids allocating memory on each iteration.
|
||||
// the call to SelectTopK overwrites any existing data so we don't need to clear on each iteration.
|
||||
std::vector<int64_t> data_holder(num_blocks);
|
||||
|
||||
// Insert the top 'k' (largest or smallest) elements into the final output buffers
|
||||
for (int64_t l = 0; l < k; ++l) {
|
||||
const auto& elem = data_holder[l];
|
||||
auto col_index = l * block_slice + j;
|
||||
values_map(i, col_index) = elem.first;
|
||||
indices_map(i, col_index) = elem.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int64_t i = start_row; i < end_row; ++i) {
|
||||
auto row_offset = i * cols;
|
||||
for (int64_t j = 0; j < block_slice; ++j) {
|
||||
SelectTopK<Comparator>(comparer, row_offset, num_blocks, block_slice, j, k, sorted, data_holder);
|
||||
|
||||
// Insert the top 'k' (largest or smallest) elements into the final output buffers
|
||||
for (int64_t l = 0; l < k; ++l) {
|
||||
int64_t idx = data_holder[l];
|
||||
auto col_index = l * block_slice + j;
|
||||
values_map(i, col_index) = input_data[idx];
|
||||
// convert overall index to result index. avoid the cost of the '/' is possible
|
||||
indices_map(i, col_index) = block_slice == 1 ? (idx - row_offset - j)
|
||||
: (idx - row_offset - j) / block_slice;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
if (num_threads <= 1) {
|
||||
find_top_k(0);
|
||||
} else {
|
||||
// we want to re-use the storage variables in each lambda as much as possible to minimize allocations
|
||||
// on each iteration, so the lambda does multiple rows. e.g. the data_holder and indices_data vectors.
|
||||
// the alternative would be to use TryBatchParallelFor with the lambda doing one row.
|
||||
threadpool->SimpleParallelFor(num_threads, find_top_k);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -204,22 +361,14 @@ static Status TopKImpl(OpKernelContext* p_op_kernel_context, const Tensor* input
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
if (sorted && largest) {
|
||||
// extract sorted largest TopK elements
|
||||
extract_top_k_elements<true, true, GreaterValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
} else if (sorted && !largest) {
|
||||
// extract sorted smallest TopK elements
|
||||
extract_top_k_elements<false, true, LesserValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
} else if (largest) {
|
||||
// extract unsorted (order undefined) largest TopK elements
|
||||
extract_top_k_elements<true, false, GreaterValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
auto* threadpool = p_op_kernel_context->GetOperatorThreadPool();
|
||||
|
||||
if (largest) {
|
||||
FindTopKElements<GreaterValueCmp<T>>(input, input_shape, values, indices, output_shape, k, sorted,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
|
||||
} else {
|
||||
// extract unsorted (order undefined) smallest TopK elements
|
||||
extract_top_k_elements<false, false, LesserValueCmp<T>>(input, input_shape, values, indices, output_shape, k,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed));
|
||||
FindTopKElements<LesserValueCmp<T>>(input, input_shape, values, indices, output_shape, k, sorted,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -425,54 +425,17 @@ TEST(TopKOperator, Top1ExplicitAxisMultiDInputSmallestElements) {
|
|||
top_1_explicit_axis_MultiD_input_smallest(11, 0); //unsorted
|
||||
}
|
||||
|
||||
TEST(TopKOperator, SelectFirstSortNext) {
|
||||
// in this test, we will select the top 5 elements first then sort the chosen 5 elements
|
||||
// Select + Sort = O(n + k * ln(k)) = 50 + 5 * ln(5) = 58.047
|
||||
// Sorted selection: O(n * ln(k)) = 50 * ln(5) = 80.47
|
||||
// The algorithm used will be Select + Sort
|
||||
std::vector<float> input_vals = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0,
|
||||
11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0,
|
||||
21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0,
|
||||
31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0,
|
||||
41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0};
|
||||
std::vector<int64_t> input_dimensions = {50};
|
||||
std::vector<float> expected_vals = {50.0f, 49.0f, 48.0f, 47.0f, 46.0f};
|
||||
std::vector<int64_t> expected_indices = {49, 48, 47, 46, 45};
|
||||
std::vector<int64_t> expected_dimensions = {5};
|
||||
int64_t axis = 0;
|
||||
RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); // largest values
|
||||
}
|
||||
|
||||
TEST(TopKOperator, SelectFirstSortNextInt64) {
|
||||
// in this test, we will select the top 5 elements first then sort the chosen 5 elements
|
||||
// Select + Sort = O(n + k * ln(k)) = 50 + 5 * ln(5) = 58.047
|
||||
// Sorted selection: O(n * ln(k)) = 50 * ln(5) = 80.47
|
||||
// The algorithm used will be Select + Sort
|
||||
std::vector<int64_t> input_vals = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
|
||||
11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
|
||||
21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
|
||||
31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
|
||||
41, 42, 43, 44, 45, 46, 47, 48, 49, 50};
|
||||
std::vector<int64_t> input_dimensions = {50};
|
||||
std::vector<int64_t> expected_vals = {50, 49, 48, 47, 46};
|
||||
std::vector<int64_t> expected_indices = {49, 48, 47, 46, 45};
|
||||
std::vector<int64_t> expected_dimensions = {5};
|
||||
int64_t axis = 0;
|
||||
RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis); // largest values
|
||||
}
|
||||
|
||||
TEST(TopKOperator, SortedSelection) {
|
||||
// in this test, we will use sorted selection (using heap)
|
||||
// Select + Sort = O(n + k * ln(k)) = 10 + 5 * ln(5) = 18.04
|
||||
// Sorted selection: O(n * ln(k)) = 10 * ln(5) = 16.09
|
||||
// The algorithm used will be Sorted selection
|
||||
std::vector<float> input_vals = {10.0f, 8.0f, 7.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 9.0f, 3.0};
|
||||
std::vector<int64_t> input_dimensions = {10};
|
||||
std::vector<float> expected_vals = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
|
||||
std::vector<int64_t> expected_indices = {6, 7, 9, 3, 4};
|
||||
std::vector<int64_t> expected_dimensions = {5};
|
||||
int64_t axis = 0;
|
||||
RunTest(11, 5, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, axis, 0); // smallest values
|
||||
// test path where SelectTopK is used (select using std::nth_element)
|
||||
// we use a custom path for n=1, and priority queue based implementation if
|
||||
// bool use_priority_queue = k != 1 && (k < 4 || (std::log2(k) / std::log2(n)) < 0.725);
|
||||
// so easiest way to test is for k to be 4 and n to be a little larger
|
||||
TEST(TopKOperator, NthElement) {
|
||||
std::vector<float> input_vals = {10.0f, 8.0f, 7.0f, 4.0f, 5.0f, 6.0f};
|
||||
std::vector<int64_t> input_dimensions = {6};
|
||||
std::vector<float> expected_vals = {10.0f, 8.0f, 7.0f, 6.0f};
|
||||
std::vector<int64_t> expected_indices = {0, 1, 2, 5};
|
||||
std::vector<int64_t> expected_dimensions = {4};
|
||||
RunTest(11, 4, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
|
||||
}
|
||||
|
||||
// test dimension in range (GridDim::maxThreadsPerBlock, GridDim::maxThreadsPerBlock * 2], ie. [257, 512]
|
||||
|
|
@ -532,5 +495,66 @@ TEST(TopKOperator, BigArrayBigTopKSorted) {
|
|||
RunTest(11, 9000, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, 0, 1, 1);
|
||||
}
|
||||
|
||||
static void top_3_all_same(int opset_version, int64_t largest = 1) {
|
||||
// whether it's largest or smallest we should pick the first instance/s of a number if there are multiple
|
||||
std::vector<float> input_vals = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f};
|
||||
std::vector<int64_t> input_dimensions = {2, 4};
|
||||
std::vector<float> expected_vals = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f};
|
||||
std::vector<int64_t> expected_indices = {0, 1, 2, 0, 1, 2};
|
||||
std::vector<int64_t> expected_dimensions = {2, 3};
|
||||
RunTest(opset_version, 3, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false, -1, largest);
|
||||
}
|
||||
|
||||
TEST(TopKOperator, Top3AllSame) {
|
||||
top_3_all_same(10);
|
||||
top_3_all_same(11);
|
||||
top_3_all_same(10, 0); // smallest
|
||||
top_3_explicit_axis(11, 0);
|
||||
}
|
||||
|
||||
static void TestThreaded(int64_t k, int64_t n, int64_t batch_size) {
|
||||
std::vector<float> input_vals(n * batch_size, 0.0f);
|
||||
std::iota(input_vals.begin(), input_vals.end(), 0.0f);
|
||||
|
||||
std::vector<int64_t> input_dimensions = {n, batch_size};
|
||||
|
||||
std::vector<float> expected_vals(n * k, 0.0f);
|
||||
std::vector<int64_t> expected_indices(n * k, 0);
|
||||
std::vector<int64_t> expected_dimensions = {n, k};
|
||||
|
||||
for (int64_t i = 0; i < n; ++i) {
|
||||
auto begin_batch_output = expected_vals.begin() + i * k;
|
||||
std::iota(begin_batch_output, begin_batch_output + k, static_cast<float>(((i + 1) * batch_size) - k));
|
||||
std::reverse(begin_batch_output, begin_batch_output + k);
|
||||
|
||||
// indices are within the axis so don't need adjusting by the batch number
|
||||
auto begin_indices_output = expected_indices.begin() + i * k;
|
||||
std::iota(begin_indices_output, begin_indices_output + k, batch_size - k);
|
||||
std::reverse(begin_indices_output, begin_indices_output + k);
|
||||
}
|
||||
|
||||
RunTest(11, k, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
|
||||
}
|
||||
|
||||
// create input of 2x1000 and select 200 so 2 threads are needed based on there being 2 rows
|
||||
// and sufficient items to process given this calculation:
|
||||
// int64_t threads_needed = static_cast<int64_t>(std::floor(input_shape.Size() * k / (128 * 1024)));
|
||||
TEST(TopKOperator, PriorityQueueThreaded) {
|
||||
const int64_t k = 200;
|
||||
const int64_t n = 2;
|
||||
const int64_t batch_size = 1000;
|
||||
TestThreaded(k, n, batch_size);
|
||||
}
|
||||
|
||||
// create input of 2x500 and select 400 so 2 threads are needed based on there being 2 rows
|
||||
// and sufficient items to process given this calculation:
|
||||
// int64_t threads_needed = static_cast<int64_t>(std::floor(input_shape.Size() * k / (128 * 1024)));
|
||||
TEST(TopKOperator, SelectTopKThreaded) {
|
||||
const int64_t k = 400;
|
||||
const int64_t n = 2;
|
||||
const int64_t batch_size = 500;
|
||||
TestThreaded(k, n, batch_size);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue